From fafbb0682c44c02a86b7597b42cf4b407a030761 Mon Sep 17 00:00:00 2001
From: "todd.chen"
Date: Thu, 29 Dec 2016 10:12:25 +0800
Subject: [PATCH 1/3] [SPARK-19018][SQL] Add csv write charset param

---
 .../main/scala/org/apache/spark/sql/DataFrameWriter.scala | 1 +
 .../sql/execution/datasources/csv/CSVFileFormat.scala     | 6 +++---
 .../spark/sql/execution/datasources/csv/CSVOptions.scala  | 4 +++-
 .../spark/sql/execution/datasources/csv/CSVParser.scala   | 7 ++++---
 4 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 9c5660a3780ad..28d85d29338a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -573,6 +573,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    * indicates a timestamp format. Custom date formats follow the formats at
    * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
    * </ul>
+   * <li>`writeEncoding` (default `utf-8`): saves the DataFrame to CSV with the given encoding.</li>
    *
    * @since 2.0.0
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index b0feaeb84e9f4..b9caabfdd0306 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -154,7 +154,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       val linesReader = new HadoopFileLinesReader(file, conf)
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
       linesReader.map { line =>
-        new String(line.getBytes, 0, line.getLength, csvOptions.charset)
+        new String(line.getBytes, 0, line.getLength, csvOptions.readCharSet)
       }
     }

@@ -195,7 +195,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession: SparkSession,
       options: CSVOptions,
       inputPaths: Seq[String]): Dataset[String] = {
-    if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
+    if (Charset.forName(options.readCharSet) == StandardCharsets.UTF_8) {
       sparkSession.baseRelationToDataFrame(
         DataSource.apply(
           sparkSession,
@@ -204,7 +204,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
         ).resolveRelation(checkFilesExist = false))
         .select("value").as[String](Encoders.STRING)
     } else {
-      val charset = options.charset
+      val charset = options.readCharSet
       val rdd = sparkSession.sparkContext
         .hadoopFile[LongWritable, Text, TextInputFormat](inputPaths.mkString(","))
         .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 21e50307b5ab0..4da037de0602d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -71,7 +71,9 @@ private[csv] class CSVOptions(@transient private val parameters: CaseInsensitive
   val delimiter = CSVTypeCast.toChar(
     parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
   private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
-  val charset = parameters.getOrElse("encoding",
+  val readCharSet = parameters.getOrElse("encoding",
+    parameters.getOrElse("charset", StandardCharsets.UTF_8.name()))
+  val writeCharSet = parameters.getOrElse("writeEncoding",
     parameters.getOrElse("charset", StandardCharsets.UTF_8.name()))

   val quote = getChar("quote", '\"')
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
index 6239508ec9422..5b68724245f24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -17,8 +17,8 @@

 package org.apache.spark.sql.execution.datasources.csv

-import java.io.{CharArrayWriter, OutputStream, StringReader}
-import java.nio.charset.StandardCharsets
+import java.io.OutputStream
+import java.nio.charset.Charset

 import com.univocity.parsers.csv._

@@ -71,6 +71,7 @@ private[csv] class LineCsvWriter(
     output: OutputStream) extends Logging {
   private val writerSettings = new CsvWriterSettings
   private val format = writerSettings.getFormat
+  private val writerCharset = Charset.forName(params.writeCharSet)

   format.setDelimiter(params.delimiter)
   format.setQuote(params.quote)
@@ -84,7 +85,7 @@ private[csv] class LineCsvWriter(
   writerSettings.setHeaders(headers: _*)
   writerSettings.setQuoteEscapingEnabled(params.escapeQuotes)

-  private val writer = new CsvWriter(output, StandardCharsets.UTF_8, writerSettings)
+  private val writer = new CsvWriter(output, writerCharset, writerSettings)

   def writeRow(row: Seq[String], includeHeader: Boolean): Unit = {
     if (includeHeader) {
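
Patch 1 threads a dedicated write charset through the CSV writer: CSVOptions now
exposes readCharSet (from `encoding`/`charset`) and writeCharSet (from the new
`writeEncoding` key), and LineCsvWriter hands writeCharSet to univocity's
CsvWriter instead of the hard-coded StandardCharsets.UTF_8. A minimal sketch of
the usage this enables; note the option key only exists in this patch (it is
renamed to `encoding` in patch 2), and the session setup and output path are
illustrative, not part of the patch:

    import org.apache.spark.sql.SparkSession

    object WriteEncodingSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]").appName("csv-write-encoding").getOrCreate()
        import spark.implicits._

        val df = Seq(("1", "中文")).toDF("num", "language")
        df.write
          .mode("overwrite")
          .option("writeEncoding", "GB18030") // picked up as CSVOptions.writeCharSet
          .option("header", "true")
          .csv("/tmp/gb18030-out")            // illustrative path
        spark.stop()
      }
    }
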
From 9b13f0041a61b738eed18f20e40dd0ccd1ef170d Mon Sep 17 00:00:00 2001
From: "todd.chen"
Date: Sat, 31 Dec 2016 10:17:23 +0800
Subject: [PATCH 2/3] [SPARK-19018][SQL] add doc and unit test, refine csv writer settings

---
 python/pyspark/sql/readwriter.py                 |  6 ++++--
 .../org/apache/spark/sql/DataFrameWriter.scala   |  2 +-
 .../datasources/csv/CSVFileFormat.scala          |  6 +++---
 .../execution/datasources/csv/CSVOptions.scala   |  4 +---
 .../execution/datasources/csv/CSVParser.scala    |  2 +-
 .../execution/datasources/csv/CSVSuite.scala     | 18 ++++++++++++++++++
 6 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b0c51b1e9992e..504d2c0fbd409 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -659,7 +659,7 @@ def text(self, path, compression=None):
         self._jwrite.text(path)

     @since(2.0)
-    def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
+    def csv(self, path, mode=None, compression=None, sep=None, encoding=None, quote=None, escape=None,
             header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
             timestampFormat=None):
         """Saves the content of the :class:`DataFrame` in CSV format at the specified path.
@@ -677,6 +677,8 @@ def csv(self, path, mode=None, compression=None, sep=None, encoding=None, quote=
                             snappy and deflate).
         :param sep: sets the single character as a separator for each field and value. If None is
                     set, it uses the default value, ``,``.
+        :param encoding: sets the encoding used to write the CSV files. If None is set,
+                         it uses the default value, ``UTF-8``.
         :param quote: sets the single character used for escaping quoted values where the
                       separator can be part of the value. If None is set, it uses the default
                       value, ``"``. If you would like to turn off quotations, you need to set an
@@ -705,7 +707,7 @@ def csv(self, path, mode=None, compression=None, sep=None, encoding=None, quote=
         >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
         """
         self.mode(mode)
-        self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header,
+        self._set_opts(compression=compression, sep=sep, encoding=encoding, quote=quote, escape=escape, header=header,
                        nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll,
                        dateFormat=dateFormat, timestampFormat=timestampFormat)
         self._jwrite.csv(path)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 28d85d29338a9..4ff5824d90d64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -573,7 +573,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    * indicates a timestamp format. Custom date formats follow the formats at
    * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
    * </ul>
-   * <li>`writeEncoding` (default `utf-8`): saves the DataFrame to CSV with the given encoding.</li>
+   * <li>`encoding` (default `utf-8`): saves the DataFrame to CSV with the given encoding.</li>
    *
    * @since 2.0.0
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index b9caabfdd0306..b0feaeb84e9f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -154,7 +154,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       val linesReader = new HadoopFileLinesReader(file, conf)
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
       linesReader.map { line =>
-        new String(line.getBytes, 0, line.getLength, csvOptions.readCharSet)
+        new String(line.getBytes, 0, line.getLength, csvOptions.charset)
       }
     }

@@ -195,7 +195,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession: SparkSession,
       options: CSVOptions,
       inputPaths: Seq[String]): Dataset[String] = {
-    if (Charset.forName(options.readCharSet) == StandardCharsets.UTF_8) {
+    if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
       sparkSession.baseRelationToDataFrame(
         DataSource.apply(
           sparkSession,
@@ -204,7 +204,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
         ).resolveRelation(checkFilesExist = false))
         .select("value").as[String](Encoders.STRING)
     } else {
-      val charset = options.readCharSet
+      val charset = options.charset
       val rdd = sparkSession.sparkContext
         .hadoopFile[LongWritable, Text, TextInputFormat](inputPaths.mkString(","))
         .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 4da037de0602d..21e50307b5ab0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -71,9 +71,7 @@ private[csv] class CSVOptions(@transient private val parameters: CaseInsensitive
   val delimiter = CSVTypeCast.toChar(
     parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
   private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
-  val readCharSet = parameters.getOrElse("encoding",
-    parameters.getOrElse("charset", StandardCharsets.UTF_8.name()))
-  val writeCharSet = parameters.getOrElse("writeEncoding",
+  val charset = parameters.getOrElse("encoding",
     parameters.getOrElse("charset", StandardCharsets.UTF_8.name()))

   val quote = getChar("quote", '\"')
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
index 5b68724245f24..025de5b292764 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -71,7 +71,7 @@ private[csv] class LineCsvWriter(
     output: OutputStream) extends Logging {
   private val writerSettings = new CsvWriterSettings
   private val format = writerSettings.getFormat
-  private val writerCharset = Charset.forName(params.writeCharSet)
+  private val writerCharset = Charset.forName(params.charset)

   format.setDelimiter(params.delimiter)
   format.setQuote(params.quote)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 491ff72337a81..d6e08d4019dbc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.sql.types._

+//noinspection ScalaStyle
 class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   import testImplicits._

@@ -905,4 +906,21 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       checkAnswer(df, Row(1, null))
     }
   }
+
+  test("save data with gb18030") {
+    withTempPath{ path =>
+      Seq(("1", "中文"))
+        .toDF("num", "language")
+        .write
+        .option("encoding", "GB18030")
+        .option("header", "true")
+        .csv(path.getAbsolutePath)
+      val df = spark.read
+        .option("header", "true")
+        .option("encoding", "GB18030")
+        .csv(path.getAbsolutePath)
+
+      checkAnswer(df, Row("1", "中文"))
+    }
+  }
 }
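
Patch 2 folds the write path back onto the single `encoding` option (with
`charset` as an alias): CSVOptions again exposes one `charset` value and
LineCsvWriter derives its writerCharset from it, so the same option now
controls decoding on read and encoding on write. The new CSVSuite test
exercises exactly this round trip; here is a spark-shell sketch of the same
flow (`spark` is the session the shell provides, and the path is illustrative):

    import spark.implicits._

    val path = "/tmp/csv-gb18030-roundtrip" // illustrative path
    val df = Seq(("1", "中文")).toDF("num", "language")

    // One `encoding` option drives the writer...
    df.write.mode("overwrite")
      .option("header", "true")
      .option("encoding", "GB18030")
      .csv(path)

    // ...and the reader, so a non-UTF-8 round trip should be lossless.
    val readBack = spark.read
      .option("header", "true")
      .option("encoding", "GB18030")
      .csv(path)
    readBack.show() // expect the original row back
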
From 724376945bb0ca15245993d9adeadb7dce2d3d5d Mon Sep 17 00:00:00 2001
From: "todd.chen"
Date: Sat, 31 Dec 2016 21:48:02 +0800
Subject: [PATCH 3/3] [SPARK-19018][SQL] refine code style

---
 python/pyspark/sql/readwriter.py                 | 12 ++++++------
 .../org/apache/spark/sql/DataFrameWriter.scala   |  3 ++-
 .../execution/datasources/csv/CSVSuite.scala     | 17 +++++++++--------
 3 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 504d2c0fbd409..874eed436c9e1 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -659,9 +659,9 @@ def text(self, path, compression=None):
         self._jwrite.text(path)

     @since(2.0)
-    def csv(self, path, mode=None, compression=None, sep=None, encoding=None, quote=None, escape=None,
+    def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
             header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
-            timestampFormat=None):
+            timestampFormat=None, encoding=None):
         """Saves the content of the :class:`DataFrame` in CSV format at the specified path.

         :param path: the path in any Hadoop supported file system
@@ -677,8 +677,6 @@ def csv(self, path, mode=None, compression=None, sep=None, encoding=None, quote=
                             snappy and deflate).
         :param sep: sets the single character as a separator for each field and value. If None is
                     set, it uses the default value, ``,``.
-        :param encoding: sets the encoding used to write the CSV files. If None is set,
-                         it uses the default value, ``UTF-8``.
         :param quote: sets the single character used for escaping quoted values where the
                       separator can be part of the value. If None is set, it uses the default
                       value, ``"``. If you would like to turn off quotations, you need to set an
@@ -703,13 +701,15 @@ def csv(self, path, mode=None, compression=None, sep=None, encoding=None, quote=
                               formats follow the formats at ``java.text.SimpleDateFormat``.
                               This applies to timestamp type. If None is set, it uses the
                               default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
+        :param encoding: encodes the CSV files by the given encoding type. If None is set,
+                         it uses the default value, ``UTF-8``.

         >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
         """
         self.mode(mode)
-        self._set_opts(compression=compression, sep=sep, encoding=encoding, quote=quote, escape=escape, header=header,
+        self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header,
                        nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll,
-                       dateFormat=dateFormat, timestampFormat=timestampFormat)
+                       dateFormat=dateFormat, timestampFormat=timestampFormat, encoding=encoding)
         self._jwrite.csv(path)

     @since(1.5)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 4ff5824d90d64..35d75dd72d5f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -572,8 +572,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
    * indicates a timestamp format. Custom date formats follow the formats at
    * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
+   * <li>`encoding` (default `UTF-8`): encodes the CSV files by the given encoding
+   * type.</li>
    * </ul>
-   * <li>`encoding` (default `utf-8`): saves the DataFrame to CSV with the given encoding.</li>
    *
    * @since 2.0.0
    */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index d6e08d4019dbc..fe962cc90e794 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.sql.types._

-//noinspection ScalaStyle
 class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   import testImplicits._

@@ -908,19 +907,21 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }

   test("save data with gb18030") {
-    withTempPath{ path =>
-      Seq(("1", "中文"))
-        .toDF("num", "language")
-        .write
-        .option("encoding", "GB18030")
+    withTempPath { path =>
+      // scalastyle:off
+      val df = Seq(("1", "中文")).toDF("num", "language")
+      // scalastyle:on
+      df.write
         .option("header", "true")
+        .option("encoding", "GB18030")
         .csv(path.getAbsolutePath)
-      val df = spark.read
+
+      val readBack = spark.read
         .option("header", "true")
         .option("encoding", "GB18030")
         .csv(path.getAbsolutePath)

-      checkAnswer(df, Row("1", "中文"))
+      checkAnswer(df, readBack)
     }
   }
 }
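
After patch 3 the surface API is settled: `encoding` is an ordinary write
option in both Scala and Python, documented next to `timestampFormat`. A
self-contained sketch that also checks the bytes on disk really use the
requested charset; the output directory handling and part-file lookup are
illustrative, not part of the patches:

    import java.nio.charset.Charset
    import java.nio.file.{Files, Paths}

    import scala.collection.JavaConverters._

    import org.apache.spark.sql.SparkSession

    object CsvEncodingEndToEnd {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]").appName("csv-encoding").getOrCreate()
        import spark.implicits._

        val out = Files.createTempDirectory("csv-gb18030").toString
        Seq(("1", "中文")).toDF("num", "language")
          .write.mode("overwrite")
          .option("header", "true")
          .option("encoding", "GB18030")
          .csv(out)

        // Decode the raw part-file bytes as GB18030 to confirm the
        // writer honored the requested charset.
        val partFile = Files.list(Paths.get(out)).iterator().asScala
          .find(_.getFileName.toString.startsWith("part-")).get
        val text = new String(Files.readAllBytes(partFile), Charset.forName("GB18030"))
        assert(text.contains("中文"))

        spark.stop()
      }
    }
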