From ce1a4e4d4a45a20903f4d510f8f743f16bbf3342 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 18 Mar 2019 20:19:49 +0800 Subject: [PATCH 1/8] csv data source V2 --- ...pache.spark.sql.sources.DataSourceRegister | 2 +- .../spark/sql/execution/command/tables.scala | 4 +- .../datasources/csv/CSVFileFormat.scala | 2 +- .../v2/PartitionReaderFromIterator.scala | 37 ++++++++ .../datasources/v2/TextBasedFileScan.scala | 45 ++++++++++ .../datasources/v2/csv/CSVDataSourceV2.scala | 57 +++++++++++++ .../v2/csv/CSVPartitionReaderFactory.scala | 72 ++++++++++++++++ .../datasources/v2/csv/CSVScan.scala | 84 +++++++++++++++++++ .../datasources/v2/csv/CSVScanBuilder.scala | 37 ++++++++ .../datasources/v2/csv/CSVTable.scala | 52 ++++++++++++ .../datasources/v2/csv/CSVWriteBuilder.scala | 65 ++++++++++++++ .../spark/sql/FileBasedDataSourceSuite.scala | 59 ++++++------- .../execution/datasources/csv/CSVSuite.scala | 10 +-- .../sql/sources/v2/DataSourceV2Suite.scala | 4 +- .../sql/test/DataFrameReaderWriterSuite.scala | 2 +- 15 files changed, 488 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderFromIterator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 7cdfddc5e7aa6..b686187552584 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,4 +1,4 @@ -org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +org.apache.spark.sql.execution.datasources.v2.csv.CSVDataSourceV2 org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider org.apache.spark.sql.execution.datasources.json.JsonFileFormat org.apache.spark.sql.execution.datasources.noop.NoopDataSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 8b70e336c14bb..08d6dc62e3542 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -37,9 +37,9 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} -import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat 
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.v2.csv.CSVDataSourceV2 import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -238,7 +238,7 @@ case class AlterTableAddColumnsCommand( // TextFileFormat only default to one column "value" // Hive type is already considered as hive serde table, so the logic will not // come in here. - case _: JsonFileFormat | _: CSVFileFormat | _: ParquetFileFormat | _: OrcDataSourceV2 => + case _: JsonFileFormat | _: CSVDataSourceV2 | _: ParquetFileFormat | _: OrcDataSourceV2 => case s if s.getClass.getCanonicalName.endsWith("OrcFileFormat") => case s => throw new AnalysisException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index d08a54cc9b1f0..4eceb86b44542 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -163,7 +163,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } -private[csv] class CsvOutputWriter( +class CsvOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderFromIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderFromIterator.scala new file mode 100644 index 0000000000000..f6cd691c5434f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderFromIterator.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.sources.v2.reader.PartitionReader + +class PartitionReaderFromIterator[InternalRow]( + iter: Iterator[InternalRow]) extends PartitionReader[InternalRow] { + private var nextValue: InternalRow = _ + + override def next(): Boolean = { + if (iter.hasNext) { + nextValue = iter.next() + true + } else { + false + } + } + + override def get(): InternalRow = nextValue + + override def close(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala new file mode 100644 index 0000000000000..8d9cc68417ef6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +abstract class TextBasedFileScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + readSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScan(sparkSession, fileIndex, readSchema, options) { + private var codecFactory: CompressionCodecFactory = _ + + override def isSplitable(path: Path): Boolean = { + if (codecFactory == null) { + codecFactory = new CompressionCodecFactory( + sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) + } + val codec = codecFactory.getCodec(path) + codec == null || codec.isInstanceOf[SplittableCompressionCodec] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala new file mode 100644 index 0000000000000..56f1169956d67 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.csv + +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.sources.v2.Table +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class CSVDataSourceV2 extends FileDataSourceV2 { + + override def fallBackFileFormat: Class[_ <: FileFormat] = classOf[CSVFileFormat] + + override def shortName(): String = "csv" + + private def getTableName(paths: Seq[String]): String = { + shortName() + ":" + paths.mkString(";") + } + + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) + val tableName = getTableName(paths) + CSVTable(tableName, sparkSession, options, paths, None) + } + + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + val paths = getPaths(options) + val tableName = getTableName(paths) + CSVTable(tableName, sparkSession, options, paths, Some(schema)) + } +} + +object CSVDataSourceV2 { + def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: AtomicType => true + + case udt: UserDefinedType[_] => supportsDataType(udt.sqlType) + + case _ => false + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala new file mode 100644 index 0000000000000..312ee48aaa618 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.csv + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.csv.CSVDataSource +import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.PartitionReader +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +/** + * A factory used to create CSV readers. + * + * @param sqlConf SQL configuration. + * @param broadcastedConf Broadcast serializable Hadoop Configuration. + * @param dataSchema Schema of CSV files. + * @param partitionSchema Schema of partitions. + * @param readSchema Required schema in the batch scan. + */ +case class CSVPartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + partitionSchema: StructType, + readSchema: StructType, + parsedOptions: CSVOptions) extends FilePartitionReaderFactory { + private val columnPruning = sqlConf.csvColumnPruning + private val readDataSchema = + getReadDataSchema(readSchema, partitionSchema, sqlConf.caseSensitiveAnalysis) + + override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { + val conf = broadcastedConf.value.value + + val parser = new UnivocityParser( + StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), + StructType(readDataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), + parsedOptions) + val schema = if (columnPruning) readDataSchema else dataSchema + val isStartOfFile = file.start == 0 + val headerChecker = new CSVHeaderChecker( + schema, parsedOptions, source = s"CSV file: ${file.filePath}", isStartOfFile) + val iter = CSVDataSource(parsedOptions).readFile( + conf, + file, + parser, + headerChecker, + readSchema) + val fileReader = new PartitionReaderFromIterator[InternalRow](iter) + new PartitionReaderWithPartitionValues(fileReader, readDataSchema, + partitionSchema, file.partitionValues) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala new file mode 100644 index 0000000000000..a5b1fe019041e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.csv + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.csv.CSVOptions +import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.csv.CSVDataSource +import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan +import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +case class CSVScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readSchema: StructType, + options: CaseInsensitiveStringMap) + extends TextBasedFileScan(sparkSession, fileIndex, readSchema, options) { + + private val optionsAsScala = options.asScala.toMap + private lazy val parsedOptions: CSVOptions = new CSVOptions( + optionsAsScala, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + + override def isSplitable(path: Path): Boolean = { + CSVDataSource(parsedOptions).isSplitable && super.isSplitable(path) + } + + override def createReaderFactory(): PartitionReaderFactory = { + // Check a field requirement for corrupt records here to throw an exception in a driver side + ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord) + + if (readSchema.length == 1 && + readSchema.head.name == parsedOptions.columnNameOfCorruptRecord) { + throw new AnalysisException( + "Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the\n" + + "referenced columns only include the internal corrupt record column\n" + + s"(named _corrupt_record by default). For example:\n" + + "spark.read.schema(schema).csv(file).filter($\"_corrupt_record\".isNotNull).count()\n" + + "and spark.read.schema(schema).csv(file).select(\"_corrupt_record\").show().\n" + + "Instead, you can cache or save the parsed results and then send the same query.\n" + + "For example, val df = spark.read.schema(schema).csv(file).cache() and then\n" + + "df.filter($\"_corrupt_record\".isNotNull).count()." + ) + } + + val hadoopConf = + sparkSession.sessionState.newHadoopConfWithOptions(optionsAsScala) + val broadcastedConf = sparkSession.sparkContext.broadcast( + new SerializableConfiguration(hadoopConf)) + CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, + dataSchema, fileIndex.partitionSchema, readSchema, parsedOptions) + } + + override def supportsDataType(dataType: DataType): Boolean = { + CSVDataSourceV2.supportsDataType(dataType) + } + + override def formatName: String = "CSV" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala new file mode 100644 index 0000000000000..dbb3c03ca9811 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2.csv + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.sources.v2.reader.Scan +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class CSVScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) extends FileScanBuilder(schema) { + + override def build(): Scan = { + CSVScan(sparkSession, fileIndex, dataSchema, readSchema, options) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala new file mode 100644 index 0000000000000..bf4b8ba868f23 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.csv + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.FileStatus + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.csv.CSVOptions +import org.apache.spark.sql.execution.datasources.csv.CSVDataSource +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.sources.v2.writer.WriteBuilder +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class CSVTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + override def newScanBuilder(options: CaseInsensitiveStringMap): CSVScanBuilder = + CSVScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + val parsedOptions = new CSVOptions( + options.asScala.toMap, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) + + CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions) + } + + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = + new CSVWriteBuilder(options, paths) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala new file mode 100644 index 0000000000000..70cda53f2461e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.csv + +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} + +import org.apache.spark.sql.catalyst.csv.CSVOptions +import org.apache.spark.sql.catalyst.util.CompressionCodecs +import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.execution.datasources.csv.CsvOutputWriter +import org.apache.spark.sql.execution.datasources.v2.FileWriteBuilder +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class CSVWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String]) + extends FileWriteBuilder(options, paths) { + override def prepareWrite( + sqlConf: SQLConf, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val conf = job.getConfiguration + val csvOptions = new CSVOptions( + options, + columnPruning = sqlConf.csvColumnPruning, + sqlConf.sessionLocalTimeZone) + csvOptions.compressionCodec.foreach { codec => + CompressionCodecs.setCodecConfiguration(conf, codec) + } + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new CsvOutputWriter(path, dataSchema, context, csvOptions) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + ".csv" + CodecStreams.getCompressionExtension(context) + } + } + } + + override def supportsDataType(dataType: DataType): Boolean = { + CSVDataSourceV2.supportsDataType(dataType) + } + + override def formatName: String = "CSV" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 58522f7b13769..a41dd43818873 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -329,27 +329,26 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { withTempDir { dir => val tempDir = new File(dir, "files").getCanonicalPath - Seq(true).foreach { useV1 => + Seq(true, false).foreach { useV1 => val useV1List = if (useV1) { - "orc" + "orc,csv" } else { "" } - def errorMessage(format: String, isWrite: Boolean): String = { - if (isWrite && (useV1 || format != "orc")) { - "cannot save interval data type into external storage." - } else { - s"$format data source does not support calendarinterval data type." - } + def validateErrorMessage(msg: String): Unit = { + val msg1 = "cannot save interval data type into external storage." + val msg2 = "data source does not support calendarinterval data type." 
+ assert(msg.toLowerCase(Locale.ROOT).contains(msg1) || + msg.toLowerCase(Locale.ROOT).contains(msg2)) } withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) { // write path Seq("csv", "json", "parquet", "orc").foreach { format => - var msg = intercept[AnalysisException] { + val msg = intercept[AnalysisException] { sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) }.getMessage - assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, true))) + validateErrorMessage(msg) } // read path @@ -359,14 +358,14 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() }.getMessage - assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false))) + validateErrorMessage(msg) msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() }.getMessage - assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false))) + validateErrorMessage(msg) } } } @@ -374,9 +373,9 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { - Seq(true).foreach { useV1 => + Seq(true, false).foreach { useV1 => val useV1List = if (useV1) { - "orc" + "orc,csv" } else { "" } @@ -470,22 +469,24 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } test("SPARK-25237 compute correct input metrics in FileScanRDD") { - withTempPath { p => - val path = p.getAbsolutePath - spark.range(1000).repartition(1).write.csv(path) - val bytesReads = new mutable.ArrayBuffer[Long]() - val bytesReadListener = new SparkListener() { - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead + withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "csv") { + withTempPath { p => + val path = p.getAbsolutePath + spark.range(1000).repartition(1).write.csv(path) + val bytesReads = new mutable.ArrayBuffer[Long]() + val bytesReadListener = new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead + } + } + sparkContext.addSparkListener(bytesReadListener) + try { + spark.read.csv(path).limit(1).collect() + sparkContext.listenerBus.waitUntilEmpty(1000L) + assert(bytesReads.sum === 7860) + } finally { + sparkContext.removeSparkListener(bytesReadListener) } - } - sparkContext.addSparkListener(bytesReadListener) - try { - spark.read.csv(path).limit(1).collect() - sparkContext.listenerBus.waitUntilEmpty(1000L) - assert(bytesReads.sum === 7860) - } finally { - sparkContext.removeSparkListener(bytesReadListener) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index d9e5d7af19671..e369596a716b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1343,15 +1343,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te .collect() }.getMessage 
assert(msg.contains("only include the internal corrupt record column")) - intercept[org.apache.spark.sql.catalyst.errors.TreeNodeException[_]] { - spark - .read - .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .schema(schema) - .csv(testFile(valueMalformedFile)) - .filter($"_corrupt_record".isNotNull) - .count() - } + // workaround val df = spark .read diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 587cfa9bd6647..1636ae08cd08e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -375,7 +375,9 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("SPARK-25700: do not read schema when writing in other modes except append and overwrite") { + // TODO: enable this one when all tests passed. + ignore("SPARK-25700: do not read schema when writing in other modes" + + " except append and overwrite") { withTempPath { file => val cls = classOf[SimpleWriteOnlyDataSource] val path = file.getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 9f969473da612..2569085bec086 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -428,7 +428,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be val message = intercept[AnalysisException] { testRead(spark.read.csv(), Seq.empty, schema) }.getMessage - assert(message.contains("Unable to infer schema for CSV. It must be specified manually.")) + assert(message.toLowerCase(Locale.ROOT).contains("unable to infer schema for csv")) testRead(spark.read.csv(dir), data, schema) testRead(spark.read.csv(dir, dir), data ++ data, schema) From b363a003f48e6d1a6fbba750daafbf8c79bd8444 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 19 Mar 2019 17:42:24 +0800 Subject: [PATCH 2/8] disable csv write path --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../spark/sql/execution/datasources/v2/FileTable.scala | 2 +- .../org/apache/spark/sql/FileBasedDataSourceSuite.scala | 6 ++++-- .../org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala | 3 +-- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 20f4080c98590..700f7e2ec183c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1482,7 +1482,7 @@ object SQLConf { " register class names for which data source V2 write paths are disabled. 
Writes from these" + " sources will fall back to the V1 sources.") .stringConf - .createWithDefault("orc") + .createWithDefault("orc,csv") val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") .doc("A comma-separated list of fully qualified data source register class names for which" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index 4b35df355b6e7..b5fccc15ffe88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -40,7 +40,7 @@ abstract class FileTable( // Hadoop Configurations are case sensitive. val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(paths, hadoopConf, - checkEmptyGlobPath = true, checkFilesExist = true) + checkEmptyGlobPath = true, checkFilesExist = false) val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) new InMemoryFileIndex( sparkSession, rootPathsSpecified, caseSensitiveMap, userSpecifiedSchema, fileStatusCache) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index a41dd43818873..9b5b71f349c0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -329,7 +329,8 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { withTempDir { dir => val tempDir = new File(dir, "files").getCanonicalPath - Seq(true, false).foreach { useV1 => + // TODO: test file source V2 after write path is fixed. + Seq(true).foreach { useV1 => val useV1List = if (useV1) { "orc,csv" } else { @@ -373,7 +374,8 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { - Seq(true, false).foreach { useV1 => + // TODO: test file source V2 after write path is fixed. + Seq(true).foreach { useV1 => val useV1List = if (useV1) { "orc,csv" } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 1636ae08cd08e..2c3033772d032 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -375,8 +375,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - // TODO: enable this one when all tests passed. 
- ignore("SPARK-25700: do not read schema when writing in other modes" + + test("SPARK-25700: do not read schema when writing in other modes" + " except append and overwrite") { withTempPath { file => val cls = classOf[SimpleWriteOnlyDataSource] From 52880117b701214e2d58125e6e72dbfcb8adfe21 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 20 Mar 2019 16:33:32 +0800 Subject: [PATCH 3/8] fix --- .../apache/spark/sql/execution/datasources/v2/FileTable.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index b5fccc15ffe88..4b35df355b6e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -40,7 +40,7 @@ abstract class FileTable( // Hadoop Configurations are case sensitive. val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(paths, hadoopConf, - checkEmptyGlobPath = true, checkFilesExist = false) + checkEmptyGlobPath = true, checkFilesExist = true) val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) new InMemoryFileIndex( sparkSession, rootPathsSpecified, caseSensitiveMap, userSpecifiedSchema, fileStatusCache) From 3f6df2d2359fd25bf9d15649ea3b5b6ba520eb20 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 21 Mar 2019 19:35:29 +0800 Subject: [PATCH 4/8] address comments --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../datasources/v2/csv/CSVPartitionReaderFactory.scala | 2 +- .../scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala | 4 ++-- .../org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala | 3 +-- .../apache/spark/sql/test/DataFrameReaderWriterSuite.scala | 2 +- 5 files changed, 6 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 700f7e2ec183c..a4ca1a0a72ae2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1482,7 +1482,7 @@ object SQLConf { " register class names for which data source V2 write paths are disabled. Writes from these" + " sources will fall back to the V1 sources.") .stringConf - .createWithDefault("orc,csv") + .createWithDefault("csv,orc") val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") .doc("A comma-separated list of fully qualified data source register class names for which" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala index 312ee48aaa618..6667891a07098 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration * A factory used to create CSV readers. * * @param sqlConf SQL configuration. - * @param broadcastedConf Broadcast serializable Hadoop Configuration. 
+ * @param broadcastedConf Broadcasted serializable Hadoop Configuration. * @param dataSchema Schema of CSV files. * @param partitionSchema Schema of partitions. * @param readSchema Required schema in the batch scan. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 9b5b71f349c0c..2e93ff9ee047d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -332,7 +332,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo // TODO: test file source V2 after write path is fixed. Seq(true).foreach { useV1 => val useV1List = if (useV1) { - "orc,csv" + "csv,orc" } else { "" } @@ -377,7 +377,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo // TODO: test file source V2 after write path is fixed. Seq(true).foreach { useV1 => val useV1List = if (useV1) { - "orc,csv" + "csv,orc" } else { "" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 2c3033772d032..587cfa9bd6647 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -375,8 +375,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("SPARK-25700: do not read schema when writing in other modes" + - " except append and overwrite") { + test("SPARK-25700: do not read schema when writing in other modes except append and overwrite") { withTempPath { file => val cls = classOf[SimpleWriteOnlyDataSource] val path = file.getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 2569085bec086..9f969473da612 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -428,7 +428,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be val message = intercept[AnalysisException] { testRead(spark.read.csv(), Seq.empty, schema) }.getMessage - assert(message.toLowerCase(Locale.ROOT).contains("unable to infer schema for csv")) + assert(message.contains("Unable to infer schema for CSV. 
It must be specified manually.")) testRead(spark.read.csv(dir), data, schema) testRead(spark.read.csv(dir, dir), data ++ data, schema) From 256e17793f5b2910123d7169e8d8a7ae130b4345 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 21 Mar 2019 23:06:27 +0800 Subject: [PATCH 5/8] use caseSensitive Map for newHadoopConfWithOptions --- .../spark/sql/execution/datasources/v2/csv/CSVScan.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index a5b1fe019041e..c3cc80b25720e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -68,8 +68,9 @@ case class CSVScan( ) } - val hadoopConf = - sparkSession.sessionState.newHadoopConfWithOptions(optionsAsScala) + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) val broadcastedConf = sparkSession.sparkContext.broadcast( new SerializableConfiguration(hadoopConf)) CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, From 2be52771da5d11ad8626227a529070fa84398adb Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 22 Mar 2019 00:00:53 +0800 Subject: [PATCH 6/8] fix test --- .../org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 9f969473da612..2569085bec086 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -428,7 +428,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be val message = intercept[AnalysisException] { testRead(spark.read.csv(), Seq.empty, schema) }.getMessage - assert(message.contains("Unable to infer schema for CSV. 
It must be specified manually.")) + assert(message.toLowerCase(Locale.ROOT).contains("unable to infer schema for csv")) testRead(spark.read.csv(dir), data, schema) testRead(spark.read.csv(dir, dir), data ++ data, schema) From 7eb54a35c8abf86fd6c15f087d171e7b896ba34a Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 22 Mar 2019 15:55:34 +0800 Subject: [PATCH 7/8] address comments --- .../datasources/csv/CSVFileFormat.scala | 30 +--------- .../datasources/csv/CsvOutputWriter.scala | 57 +++++++++++++++++++ .../datasources/v2/FileDataSourceV2.scala | 4 ++ .../v2/PartitionReaderFromIterator.scala | 6 +- .../datasources/v2/csv/CSVDataSourceV2.scala | 8 +-- .../v2/csv/CSVPartitionReaderFactory.scala | 4 +- .../datasources/v2/csv/CSVScan.scala | 5 +- .../datasources/v2/csv/CSVWriteBuilder.scala | 2 +- .../datasources/v2/orc/OrcDataSourceV2.scala | 4 -- .../spark/sql/FileBasedDataSourceSuite.scala | 1 + 10 files changed, 73 insertions(+), 48 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 4eceb86b44542..c8de53a17acad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -118,7 +118,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { throw new AnalysisException( "Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the\n" + "referenced columns only include the internal corrupt record column\n" + - s"(named _corrupt_record by default). For example:\n" + + "(named _corrupt_record by default). For example:\n" + "spark.read.schema(schema).csv(file).filter($\"_corrupt_record\".isNotNull).count()\n" + "and spark.read.schema(schema).csv(file).select(\"_corrupt_record\").show().\n" + "Instead, you can cache or save the parsed results and then send the same query.\n" + @@ -163,31 +163,3 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } -class CsvOutputWriter( - path: String, - dataSchema: StructType, - context: TaskAttemptContext, - params: CSVOptions) extends OutputWriter with Logging { - - private var univocityGenerator: Option[UnivocityGenerator] = None - - if (params.headerFlag) { - val gen = getGen() - gen.writeHeaders() - } - - private def getGen(): UnivocityGenerator = univocityGenerator.getOrElse { - val charset = Charset.forName(params.charset) - val os = CodecStreams.createOutputStreamWriter(context, new Path(path), charset) - val newGen = new UnivocityGenerator(dataSchema, os, params) - univocityGenerator = Some(newGen) - newGen - } - - override def write(row: InternalRow): Unit = { - val gen = getGen() - gen.write(row) - } - - override def close(): Unit = univocityGenerator.foreach(_.close()) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala new file mode 100644 index 0000000000000..97d0a789b20ad --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.csv + +import java.nio.charset.Charset + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityGenerator} +import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter} +import org.apache.spark.sql.types.StructType + +private[sql] class CsvOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext, + params: CSVOptions) extends OutputWriter with Logging { + + private var univocityGenerator: Option[UnivocityGenerator] = None + + if (params.headerFlag) { + val gen = getGen() + gen.writeHeaders() + } + + private def getGen(): UnivocityGenerator = univocityGenerator.getOrElse { + val charset = Charset.forName(params.charset) + val os = CodecStreams.createOutputStreamWriter(context, new Path(path), charset) + val newGen = new UnivocityGenerator(dataSchema, os, params) + univocityGenerator = Some(newGen) + newGen + } + + override def write(row: InternalRow): Unit = { + val gen = getGen() + gen.write(row) + } + + override def close(): Unit = univocityGenerator.foreach(_.close()) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala index e9c7a1bb749db..ebe7fee312e89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -47,4 +47,8 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister { Option(map.get("path")).toSeq } } + + protected def getTableName(paths: Seq[String]): String = { + shortName() + ":" + paths.mkString(";") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderFromIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderFromIterator.scala index f6cd691c5434f..f9dfcf448a3ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderFromIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderFromIterator.scala @@ -20,18 +20,18 @@ import org.apache.spark.sql.sources.v2.reader.PartitionReader class PartitionReaderFromIterator[InternalRow]( iter: Iterator[InternalRow]) extends PartitionReader[InternalRow] { - private var nextValue: InternalRow = _ + private var currentValue: InternalRow = _ override def next(): Boolean = { if (iter.hasNext) { - nextValue = iter.next() + currentValue = iter.next() true } else { false } } 
- override def get(): InternalRow = nextValue + override def get(): InternalRow = currentValue override def close(): Unit = {} } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala index 56f1169956d67..4ecd9cdc32acf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala @@ -16,9 +16,9 @@ */ package org.apache.spark.sql.execution.datasources.v2.csv -import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat -import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.sources.v2.Table import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -29,10 +29,6 @@ class CSVDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "csv" - private def getTableName(paths: Seq[String]): String = { - shortName() + ":" + paths.mkString(";") - } - override def getTable(options: CaseInsensitiveStringMap): Table = { val paths = getPaths(options) val tableName = getTableName(paths) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala index 6667891a07098..e2d50282e9cba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala @@ -35,6 +35,7 @@ import org.apache.spark.util.SerializableConfiguration * @param dataSchema Schema of CSV files. * @param partitionSchema Schema of partitions. * @param readSchema Required schema in the batch scan. + * @param parsedOptions Options for parsing CSV files. 
*/ case class CSVPartitionReaderFactory( sqlConf: SQLConf, @@ -63,10 +64,9 @@ case class CSVPartitionReaderFactory( file, parser, headerChecker, - readSchema) + readDataSchema) val fileReader = new PartitionReaderFromIterator[InternalRow](iter) new PartitionReaderWithPartitionValues(fileReader, readDataSchema, partitionSchema, file.partitionValues) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index c3cc80b25720e..35c6a668f22a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -39,9 +39,8 @@ case class CSVScan( options: CaseInsensitiveStringMap) extends TextBasedFileScan(sparkSession, fileIndex, readSchema, options) { - private val optionsAsScala = options.asScala.toMap private lazy val parsedOptions: CSVOptions = new CSVOptions( - optionsAsScala, + options.asScala.toMap, columnPruning = sparkSession.sessionState.conf.csvColumnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -59,7 +58,7 @@ case class CSVScan( throw new AnalysisException( "Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the\n" + "referenced columns only include the internal corrupt record column\n" + - s"(named _corrupt_record by default). For example:\n" + + "(named _corrupt_record by default). For example:\n" + "spark.read.schema(schema).csv(file).filter($\"_corrupt_record\".isNotNull).count()\n" + "and spark.read.schema(schema).csv(file).select(\"_corrupt_record\").show().\n" + "Instead, you can cache or save the parsed results and then send the same query.\n" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala index 70cda53f2461e..bb26d2f92d74b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter, O import org.apache.spark.sql.execution.datasources.csv.CsvOutputWriter import org.apache.spark.sql.execution.datasources.v2.FileWriteBuilder import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap class CSVWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index 900c94e937ffc..36e7e12e41cec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -29,10 +29,6 @@ class OrcDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "orc" - private def getTableName(paths: Seq[String]): String = { - shortName() + ":" + paths.mkString(";") - } - override def getTable(options: CaseInsensitiveStringMap): Table = 
{ val paths = getPaths(options) val tableName = getTableName(paths) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 2e93ff9ee047d..1d30cbfbaf1a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -471,6 +471,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } test("SPARK-25237 compute correct input metrics in FileScanRDD") { + // TODO: Test CSV V2 as well after it implements [[SupportsReportStatistics]]. withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "csv") { withTempPath { p => val path = p.getAbsolutePath From 689f17e4930694d1e7585ba82e0964507121868a Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 22 Mar 2019 15:59:39 +0800 Subject: [PATCH 8/8] revise --- .../spark/sql/execution/datasources/csv/CsvOutputWriter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala index 97d0a789b20ad..3ff36bfde3cca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityGenerator} import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter} import org.apache.spark.sql.types.StructType -private[sql] class CsvOutputWriter( +class CsvOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext,
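
For reference, a minimal end-to-end sketch of the behavior this series converges on (a sketch, not part of the patch series): after the final commit, reads through the "csv" source resolve to the new V2 classes (CSVDataSourceV2 -> CSVTable -> CSVScan), while writes still fall back to the V1 CSVFileFormat / CsvOutputWriter because "csv" is in the disabled-V2-writers default ("csv,orc"). The object name, local path, and column names below are made up for illustration; the only Spark pieces assumed are the public APIs and the SQLConf constants (USE_V1_SOURCE_READER_LIST / USE_V1_SOURCE_WRITER_LIST) that already appear in the tests above.

import org.apache.spark.sql.SparkSession

object CsvV2Sketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("csv-v2-sketch")
      .getOrCreate()
    import spark.implicits._

    // Write path: "csv" sits in the V2-writer disable list (default "csv,orc" above),
    // so this still runs through the V1 CSVFileFormat / CsvOutputWriter.
    Seq((1, "a"), (2, "b")).toDF("id", "name")
      .write.mode("overwrite").option("header", "true").csv("/tmp/csv-v2-sketch")

    // Read path: "csv" now resolves to CSVDataSourceV2 -> CSVTable -> CSVScan,
    // unless "csv" is added to SQLConf.USE_V1_SOURCE_READER_LIST, as some tests above do.
    val df = spark.read.option("header", "true").csv("/tmp/csv-v2-sketch")
    df.show()

    spark.stop()
  }
}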