From 17c2d50c9d4788aaa68f7f57fc873762940a8e9d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 10 Sep 2016 08:31:26 -0700 Subject: [PATCH 1/7] fix --- .../apache/spark/sql/DataFrameReader.scala | 3 +- .../command/createDataSourceTables.scala | 5 +- .../execution/datasources/DataSource.scala | 50 ++++++++++---- .../datasources/DataSourceStrategy.scala | 2 +- .../spark/sql/execution/datasources/ddl.scala | 3 +- .../streaming/FileStreamSource.scala | 2 +- .../sql/streaming/DataStreamReader.scala | 2 +- .../datasources/json/JsonSuite.scala | 4 +- .../spark/sql/sources/InsertSuite.scala | 20 ++++++ .../spark/sql/sources/TableScanSuite.scala | 68 ++++++++++++------- .../sql/test/DataFrameReaderWriterSuite.scala | 33 +++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 15 ++++ 13 files changed, 161 insertions(+), 50 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 93bf74d06b71d..1dcc75afbf3da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -144,7 +144,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { DataSource.apply( sparkSession, paths = paths, - userSpecifiedSchema = userSpecifiedSchema, + inputSchema = userSpecifiedSchema, + isSchemaFromUsers = true, className = source, options = extraOptions.toMap).resolveRelation()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index b1830e6cf3ea8..5148ad170596e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -60,7 +60,8 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo val dataSource: BaseRelation = DataSource( sparkSession = sparkSession, - userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema), + inputSchema = if (table.schema.isEmpty) None else Some(table.schema), + isSchemaFromUsers = true, className = table.provider.get, bucketSpec = table.bucketSpec, options = table.storage.properties).resolveRelation() @@ -156,7 +157,7 @@ case class CreateDataSourceTableAsSelectCommand( // Check if the specified data source match the data source of the existing table. val dataSource = DataSource( sparkSession = sparkSession, - userSpecifiedSchema = Some(query.schema.asNullable), + inputSchema = Some(query.schema.asNullable), partitionColumns = table.partitionColumnNames, bucketSpec = table.bucketSpec, className = provider, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 71807b771a95f..0df13e320fe2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -60,8 +60,9 @@ import org.apache.spark.util.Utils * * @param paths A list of file system paths that hold data. These will be globbed before and * qualified. This option only works when reading from a [[FileFormat]]. 
- * @param userSpecifiedSchema An optional specification of the schema of the data. When present - * we skip attempting to infer the schema. + * @param inputSchema An optional specification of the schema of the data. When present we skip + * attempting to infer the schema. + * @param isSchemaFromUsers A flag to indicate whether the schema is specified by users. * @param partitionColumns A list of column names that the relation is partitioned by. When this * list is empty, the relation is unpartitioned. * @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data. @@ -70,7 +71,8 @@ case class DataSource( sparkSession: SparkSession, className: String, paths: Seq[String] = Nil, - userSpecifiedSchema: Option[StructType] = None, + inputSchema: Option[StructType] = None, + isSchemaFromUsers: Boolean = false, partitionColumns: Seq[String] = Seq.empty, bucketSpec: Option[BucketSpec] = None, options: Map[String, String] = Map.empty) extends Logging { @@ -186,7 +188,7 @@ case class DataSource( } private def inferFileFormatSchema(format: FileFormat): StructType = { - userSpecifiedSchema.orElse { + inputSchema.orElse { val caseInsensitiveOptions = new CaseInsensitiveMap(options) val allPaths = caseInsensitiveOptions.get("path") val globbedPaths = allPaths.toSeq.flatMap { path => @@ -210,7 +212,7 @@ case class DataSource( providingClass.newInstance() match { case s: StreamSourceProvider => val (name, schema) = s.sourceSchema( - sparkSession.sqlContext, userSpecifiedSchema, className, options) + sparkSession.sqlContext, inputSchema, className, options) SourceInfo(name, schema) case format: FileFormat => @@ -233,7 +235,7 @@ case class DataSource( val isSchemaInferenceEnabled = sparkSession.conf.get(SQLConf.STREAMING_SCHEMA_INFERENCE) val isTextSource = providingClass == classOf[text.TextFileFormat] // If the schema inference is disabled, only text sources require schema to be specified - if (!isSchemaInferenceEnabled && !isTextSource && userSpecifiedSchema.isEmpty) { + if (!isSchemaInferenceEnabled && !isTextSource && inputSchema.isEmpty) { throw new IllegalArgumentException( "Schema must be specified when creating a streaming source DataFrame. " + "If some files already exist in the directory, then depending on the file format " + @@ -252,8 +254,7 @@ case class DataSource( def createSource(metadataPath: String): Source = { providingClass.newInstance() match { case s: StreamSourceProvider => - s.createSource( - sparkSession.sqlContext, metadataPath, userSpecifiedSchema, className, options) + s.createSource(sparkSession.sqlContext, metadataPath, inputSchema, className, options) case format: FileFormat => val path = new CaseInsensitiveMap(options).getOrElse("path", { @@ -312,13 +313,28 @@ case class DataSource( } } + /** + * Check whether users are allowed to provide schema for this data source. 
+ */ + def checkSchemaAssignable(): Unit = { + val notExtendedSchemaRelationProvider = try { + !classOf[SchemaRelationProvider].isAssignableFrom(providingClass) + } catch { + case NonFatal(e) => false + } + if (notExtendedSchemaRelationProvider) { + throw new AnalysisException(s"$providingClass does not allow user-specified schemas") + } + } + + /** * Create a resolved [[BaseRelation]] that can be used to read data from or write data into this * [[DataSource]] */ def resolveRelation(): BaseRelation = { val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val relation = (providingClass.newInstance(), userSpecifiedSchema) match { + val relation = (providingClass.newInstance(), inputSchema) match { // TODO: Throw when too much is given. case (dataSource: SchemaRelationProvider, Some(schema)) => dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema) @@ -326,8 +342,12 @@ case class DataSource( dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions) case (_: SchemaRelationProvider, None) => throw new AnalysisException(s"A schema needs to be specified when using $className.") - case (_: RelationProvider, Some(_)) => - throw new AnalysisException(s"$className does not allow user-specified schemas.") + case (dataSource: RelationProvider, Some(_)) => + if (isSchemaFromUsers) { + throw new AnalysisException(s"$className does not allow user-specified schemas.") + } else { + dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions) + } // We are reading from the results of a streaming query. Load files from the metadata log // instead of listing them using HDFS APIs. @@ -335,7 +355,7 @@ case class DataSource( if hasMetadata(caseInsensitiveOptions.get("path").toSeq ++ paths) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) val fileCatalog = new MetadataLogFileCatalog(sparkSession, basePath) - val dataSchema = userSpecifiedSchema.orElse { + val dataSchema = inputSchema.orElse { format.inferSchema( sparkSession, caseInsensitiveOptions, @@ -375,7 +395,7 @@ case class DataSource( // If they gave a schema, then we try and figure out the types of the partition columns // from that schema. - val partitionSchema = userSpecifiedSchema.map { schema => + val partitionSchema = inputSchema.map { schema => StructType( partitionColumns.map { c => // TODO: Case sensitivity. @@ -389,7 +409,7 @@ case class DataSource( new ListingFileCatalog( sparkSession, globbedPaths, options, partitionSchema) - val dataSchema = userSpecifiedSchema.map { schema => + val dataSchema = inputSchema.map { schema => val equality = sparkSession.sessionState.conf.resolver StructType(schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) }.orElse { @@ -499,7 +519,7 @@ case class DataSource( mode) sparkSession.sessionState.executePlan(plan).toRdd // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it. 
- copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() + copy(inputSchema = Some(data.schema.asNullable)).resolveRelation() case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c8ad5b303491f..b151a46d6850f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -201,7 +201,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] val dataSource = DataSource( sparkSession, - userSpecifiedSchema = Some(table.schema), + inputSchema = Some(table.schema), partitionColumns = table.partitionColumnNames, bucketSpec = table.bucketSpec, className = table.provider.get, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 1b1e2123b7c47..a1e91cdcd7a3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -55,7 +55,8 @@ case class CreateTempViewUsing( def run(sparkSession: SparkSession): Seq[Row] = { val dataSource = DataSource( sparkSession, - userSpecifiedSchema = userSpecifiedSchema, + inputSchema = userSpecifiedSchema, + isSchemaFromUsers = true, className = provider, options = options) sparkSession.sessionState.catalog.createTempView( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 42fb454c2d158..4bdc052375591 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -133,7 +133,7 @@ class FileStreamSource( DataSource( sparkSession, paths = files.map(_.path), - userSpecifiedSchema = Some(schema), + inputSchema = Some(schema), className = fileFormatClassName, options = sourceOptions.optionMapWithoutPath) Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 3ad1125229c97..f1a258746e6ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -136,7 +136,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo val dataSource = DataSource( sparkSession, - userSpecifiedSchema = userSpecifiedSchema, + inputSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap) Dataset.ofRows(sparkSession, StreamingRelation(dataSource)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 3d533c14e18e7..4b998ef389ac2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1346,7 +1346,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val d1 = DataSource( spark, - userSpecifiedSchema = None, + inputSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, className = classOf[JsonFileFormat].getCanonicalName, @@ -1354,7 +1354,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val d2 = DataSource( spark, - userSpecifiedSchema = None, + inputSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, className = classOf[JsonFileFormat].getCanonicalName, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 6454d716ec0db..29902bbf89acc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -65,6 +65,26 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { ) } + test("insert into a temp view that does not point to an insertable data source") { + import testImplicits._ + withTempView("t1", "t2") { + sql( + """ + |CREATE TEMPORARY TABLE t1 + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | From '1', + | To '10') + """.stripMargin) + sparkContext.parallelize(1 to 10).toDF("a").createOrReplaceTempView("t2") + + val message = intercept[AnalysisException] { + sql("INSERT INTO TABLE t1 SELECT a FROM t2") + }.getMessage + assert(message.contains("does not allow insertion")) + } + } + test("PreInsert casting and renaming") { sql( s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index e8fed039fa993..874638e0bcd10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -345,34 +345,54 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { (1 to 10).map(Row(_)).toSeq) } + test("create a temp table that does not have a path in the option") { + Seq("TEMPORARY VIEW", "TABLE").foreach { tableType => + val tableName = "relationProvierWithSchema" + withTable(tableName) { + sql( + s""" + |CREATE $tableType $tableName + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + checkAnswer(spark.table(tableName), spark.range(1, 11).toDF()) + } + } + } + test("exceptions") { // Make sure we do throw correct exception when users use a relation provider that // only implements the RelationProvider or the SchemaRelationProvider. 
- val schemaNotAllowed = intercept[Exception] { - sql( - """ - |CREATE TEMPORARY VIEW relationProvierWithSchema (i int) - |USING org.apache.spark.sql.sources.SimpleScanSource - |OPTIONS ( - | From '1', - | To '10' - |) - """.stripMargin) - } - assert(schemaNotAllowed.getMessage.contains("does not allow user-specified schemas")) - - val schemaNeeded = intercept[Exception] { - sql( - """ - |CREATE TEMPORARY VIEW schemaRelationProvierWithoutSchema - |USING org.apache.spark.sql.sources.AllDataTypesScanSource - |OPTIONS ( - | From '1', - | To '10' - |) - """.stripMargin) + Seq("TEMPORARY VIEW", "TABLE").foreach { tableType => + val schemaNotAllowed = intercept[Exception] { + sql( + s""" + |CREATE $tableType relationProvierWithSchema (i int) + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + } + assert(schemaNotAllowed.getMessage.contains("does not allow user-specified schemas")) + + val schemaNeeded = intercept[Exception] { + sql( + s""" + |CREATE $tableType schemaRelationProvierWithoutSchema + |USING org.apache.spark.sql.sources.AllDataTypesScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + } + assert(schemaNeeded.getMessage.contains("A schema needs to be specified when using")) } - assert(schemaNeeded.getMessage.contains("A schema needs to be specified when using")) } test("SPARK-5196 schema field with comment") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 63b0e4588e4a6..dadd8b28706e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -292,6 +292,39 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be Option(dir).map(spark.read.format("org.apache.spark.sql.test").load) } + test("read a data source that does not extend SchemaRelationProvider") { + val dfReader = spark.read + .option("from", "1") + .option("TO", "10") + .format("org.apache.spark.sql.sources.SimpleScanSource") + + // when users do not specify the schema + checkAnswer(dfReader.load(), spark.range(1, 11).toDF()) + + // when users specify the schema + val inputSchema = new StructType().add("s", IntegerType, nullable = false) + val e = intercept[AnalysisException] { dfReader.schema(inputSchema).load() } + assert(e.getMessage.contains( + "org.apache.spark.sql.sources.SimpleScanSource does not allow user-specified schemas")) + } + + test("read a data source that does not extend RelationProvider") { + val dfReader = spark.read + .option("from", "1") + .option("TO", "10") + .option("option_with_underscores", "someval") + .option("option.with.dots", "someval") + .format("org.apache.spark.sql.sources.AllDataTypesScanSource") + + // when users do not specify the schema + val e = intercept[AnalysisException] { dfReader.load() } + assert(e.getMessage.contains("A schema needs to be specified when using")) + + // when users specify the schema + val inputSchema = new StructType().add("s", StringType, nullable = false) + assert(dfReader.schema(inputSchema).load().count() == 10) + } + test("text - API and behavior regarding schema") { // Writer spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 8410a2e4a47ca..b2e6ddbfa1862 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -74,7 +74,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val dataSource = DataSource( sparkSession, - userSpecifiedSchema = Some(table.schema), + inputSchema = Some(table.schema), partitionColumns = table.partitionColumnNames, bucketSpec = table.bucketSpec, className = table.provider.get, @@ -278,7 +278,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log DataSource( sparkSession = sparkSession, paths = paths, - userSpecifiedSchema = Some(metastoreRelation.schema), + inputSchema = Some(metastoreRelation.schema), bucketSpec = bucketSpec, options = options, className = fileType).resolveRelation(), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 3466733d7fdcd..37d2d9d4fc149 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1248,6 +1248,21 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } + test("create a temp table that does not have a path in the option") { + withTable("t1") { + sql( + """ + |CREATE TABLE t1 (i int) + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | From '1', + | To '10') + """.stripMargin) + + spark.table("t1").show() + } + } + test("read table with corrupted schema") { try { val schema = StructType(StructField("int", IntegerType, true) :: Nil) From 00a49fe60f86775e19f038791a766195d506087a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 10 Sep 2016 08:41:18 -0700 Subject: [PATCH 2/7] clean --- .../sql/hive/MetastoreDataSourcesSuite.scala | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 37d2d9d4fc149..3466733d7fdcd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1248,21 +1248,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } - test("create a temp table that does not have a path in the option") { - withTable("t1") { - sql( - """ - |CREATE TABLE t1 (i int) - |USING org.apache.spark.sql.sources.SimpleScanSource - |OPTIONS ( - | From '1', - | To '10') - """.stripMargin) - - spark.table("t1").show() - } - } - test("read table with corrupted schema") { try { val schema = StructType(StructField("int", IntegerType, true) :: Nil) From 335e0d6d5a19b30ec000db8d935869e006dd81e7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 10 Sep 2016 08:42:11 -0700 Subject: [PATCH 3/7] clean --- .../sql/execution/datasources/DataSource.scala | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 0df13e320fe2a..710d39abe23f9 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -313,21 +313,6 @@ case class DataSource( } } - /** - * Check whether users are allowed to provide schema for this data source. - */ - def checkSchemaAssignable(): Unit = { - val notExtendedSchemaRelationProvider = try { - !classOf[SchemaRelationProvider].isAssignableFrom(providingClass) - } catch { - case NonFatal(e) => false - } - if (notExtendedSchemaRelationProvider) { - throw new AnalysisException(s"$providingClass does not allow user-specified schemas") - } - } - - /** * Create a resolved [[BaseRelation]] that can be used to read data from or write data into this * [[DataSource]] From 4ab1b8a45c9a8b9ed1f7ee85202eddf397235df4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 10 Sep 2016 09:12:26 -0700 Subject: [PATCH 4/7] add one more test case --- .../spark/sql/sources/TableScanSuite.scala | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 874638e0bcd10..9d81a2ce5e4d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -395,6 +395,24 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { } } + test("read the data source tables that do not extend SchemaRelationProvider") { + Seq("TEMPORARY VIEW", "TABLE").foreach { tableType => + val tableName = "relationProvierWithSchema" + withTable (tableName) { + sql( + s""" + |CREATE $tableType $tableName + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + checkAnswer(spark.table(tableName), spark.range(1, 11).toDF()) + } + } + } + test("SPARK-5196 schema field with comment") { sql( """ From 55ee86456eb0823f91a1130f5823f32b1b502ef6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 19 Sep 2016 14:07:18 -0700 Subject: [PATCH 5/7] address comments. 
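
Rather than threading a separate isSchemaFromUsers flag through DataSource, this revision keeps the original userSpecifiedSchema parameter and moves the decision into resolveRelation: a source that only implements RelationProvider now accepts a supplied schema when it equals the schema the relation itself reports, and keeps rejecting anything else. A rough sketch of the intended behavior, using the SimpleScanSource test source that the suites in this series rely on (assumes a SparkSession named spark and that test source on the classpath):

    import org.apache.spark.sql.AnalysisException
    import org.apache.spark.sql.types.{IntegerType, StructType}

    val reader = spark.read
      .format("org.apache.spark.sql.sources.SimpleScanSource")
      .option("from", "1")
      .option("TO", "10")

    // No user schema: the relation's own schema is used, exactly as before.
    assert(reader.load().count() == 10)

    // A schema that differs from what the source reports is still rejected.
    try {
      reader.schema(new StructType().add("s", IntegerType, nullable = false)).load()
      assert(false, "expected an AnalysisException")
    } catch {
      case e: AnalysisException =>
        assert(e.getMessage.contains("does not allow user-specified schemas"))
    }

A catalog-backed read keeps working under this rule, because the schema stored in the catalog matches the schema the relation reports.
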
--- .../apache/spark/sql/DataFrameReader.scala | 3 +- .../command/createDataSourceTables.scala | 5 ++- .../execution/datasources/DataSource.scala | 36 +++++++++---------- .../datasources/DataSourceStrategy.scala | 2 +- .../spark/sql/execution/datasources/ddl.scala | 3 +- .../streaming/FileStreamSource.scala | 2 +- .../sql/streaming/DataStreamReader.scala | 2 +- .../datasources/json/JsonSuite.scala | 4 +-- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +-- 9 files changed, 29 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index acb9e3dbf2c5b..30f39c70fe0bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -144,8 +144,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { DataSource.apply( sparkSession, paths = paths, - inputSchema = userSpecifiedSchema, - isSchemaFromUsers = true, + userSpecifiedSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap).resolveRelation()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 54832e03a0b68..d8e20b09c1add 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -64,8 +64,7 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo val dataSource: BaseRelation = DataSource( sparkSession = sparkSession, - inputSchema = if (table.schema.isEmpty) None else Some(table.schema), - isSchemaFromUsers = true, + userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema), className = table.provider.get, bucketSpec = table.bucketSpec, options = table.storage.properties).resolveRelation() @@ -165,7 +164,7 @@ case class CreateDataSourceTableAsSelectCommand( // Check if the specified data source match the data source of the existing table. val dataSource = DataSource( sparkSession = sparkSession, - inputSchema = Some(query.schema.asNullable), + userSpecifiedSchema = Some(query.schema.asNullable), partitionColumns = table.partitionColumnNames, bucketSpec = table.bucketSpec, className = provider, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index c292ef83a78ab..ba36f6dbb2b4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -60,9 +60,8 @@ import org.apache.spark.util.Utils * * @param paths A list of file system paths that hold data. These will be globbed before and * qualified. This option only works when reading from a [[FileFormat]]. - * @param inputSchema An optional specification of the schema of the data. When present we skip - * attempting to infer the schema. - * @param isSchemaFromUsers A flag to indicate whether the schema is specified by users. + * @param userSpecifiedSchema An optional specification of the schema of the data. When present + * we skip attempting to infer the schema. 
* @param partitionColumns A list of column names that the relation is partitioned by. When this * list is empty, the relation is unpartitioned. * @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data. @@ -71,8 +70,7 @@ case class DataSource( sparkSession: SparkSession, className: String, paths: Seq[String] = Nil, - inputSchema: Option[StructType] = None, - isSchemaFromUsers: Boolean = false, + userSpecifiedSchema: Option[StructType] = None, partitionColumns: Seq[String] = Seq.empty, bucketSpec: Option[BucketSpec] = None, options: Map[String, String] = Map.empty) extends Logging { @@ -189,7 +187,7 @@ case class DataSource( } private def inferFileFormatSchema(format: FileFormat): StructType = { - inputSchema.orElse { + userSpecifiedSchema.orElse { val caseInsensitiveOptions = new CaseInsensitiveMap(options) val allPaths = caseInsensitiveOptions.get("path") val globbedPaths = allPaths.toSeq.flatMap { path => @@ -213,7 +211,7 @@ case class DataSource( providingClass.newInstance() match { case s: StreamSourceProvider => val (name, schema) = s.sourceSchema( - sparkSession.sqlContext, inputSchema, className, options) + sparkSession.sqlContext, userSpecifiedSchema, className, options) SourceInfo(name, schema) case format: FileFormat => @@ -236,7 +234,7 @@ case class DataSource( val isSchemaInferenceEnabled = sparkSession.sessionState.conf.streamingSchemaInference val isTextSource = providingClass == classOf[text.TextFileFormat] // If the schema inference is disabled, only text sources require schema to be specified - if (!isSchemaInferenceEnabled && !isTextSource && inputSchema.isEmpty) { + if (!isSchemaInferenceEnabled && !isTextSource && userSpecifiedSchema.isEmpty) { throw new IllegalArgumentException( "Schema must be specified when creating a streaming source DataFrame. " + "If some files already exist in the directory, then depending on the file format " + @@ -255,7 +253,8 @@ case class DataSource( def createSource(metadataPath: String): Source = { providingClass.newInstance() match { case s: StreamSourceProvider => - s.createSource(sparkSession.sqlContext, metadataPath, inputSchema, className, options) + s.createSource( + sparkSession.sqlContext, metadataPath, userSpecifiedSchema, className, options) case format: FileFormat => val path = new CaseInsensitiveMap(options).getOrElse("path", { @@ -320,7 +319,7 @@ case class DataSource( */ def resolveRelation(): BaseRelation = { val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val relation = (providingClass.newInstance(), inputSchema) match { + val relation = (providingClass.newInstance(), userSpecifiedSchema) match { // TODO: Throw when too much is given. 
case (dataSource: SchemaRelationProvider, Some(schema)) => dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema) @@ -328,12 +327,13 @@ case class DataSource( dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions) case (_: SchemaRelationProvider, None) => throw new AnalysisException(s"A schema needs to be specified when using $className.") - case (dataSource: RelationProvider, Some(_)) => - if (isSchemaFromUsers) { - throw new AnalysisException(s"$className does not allow user-specified schemas.") - } else { + case (dataSource: RelationProvider, Some(schema)) => + val baseRelation = dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions) + if (baseRelation.schema != schema) { + throw new AnalysisException(s"$className does not allow user-specified schemas.") } + baseRelation // We are reading from the results of a streaming query. Load files from the metadata log // instead of listing them using HDFS APIs. @@ -341,7 +341,7 @@ case class DataSource( if hasMetadata(caseInsensitiveOptions.get("path").toSeq ++ paths) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) val fileCatalog = new MetadataLogFileCatalog(sparkSession, basePath) - val dataSchema = inputSchema.orElse { + val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sparkSession, caseInsensitiveOptions, @@ -381,7 +381,7 @@ case class DataSource( // If they gave a schema, then we try and figure out the types of the partition columns // from that schema. - val partitionSchema = inputSchema.map { schema => + val partitionSchema = userSpecifiedSchema.map { schema => StructType( partitionColumns.map { c => // TODO: Case sensitivity. @@ -395,7 +395,7 @@ case class DataSource( new ListingFileCatalog( sparkSession, globbedPaths, options, partitionSchema) - val dataSchema = inputSchema.map { schema => + val dataSchema = userSpecifiedSchema.map { schema => val equality = sparkSession.sessionState.conf.resolver StructType(schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) }.orElse { @@ -505,7 +505,7 @@ case class DataSource( mode) sparkSession.sessionState.executePlan(plan).toRdd // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it. 
- copy(inputSchema = Some(data.schema.asNullable)).resolveRelation() + copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index b151a46d6850f..c8ad5b303491f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -201,7 +201,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] val dataSource = DataSource( sparkSession, - inputSchema = Some(table.schema), + userSpecifiedSchema = Some(table.schema), partitionColumns = table.partitionColumnNames, bucketSpec = table.bucketSpec, className = table.provider.get, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 0ba7173d27376..fa95af2648cf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -55,8 +55,7 @@ case class CreateTempViewUsing( def run(sparkSession: SparkSession): Seq[Row] = { val dataSource = DataSource( sparkSession, - inputSchema = userSpecifiedSchema, - isSchemaFromUsers = true, + userSpecifiedSchema = userSpecifiedSchema, className = provider, options = options) sparkSession.sessionState.catalog.createTempView( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 4bdc052375591..42fb454c2d158 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -133,7 +133,7 @@ class FileStreamSource( DataSource( sparkSession, paths = files.map(_.path), - inputSchema = Some(schema), + userSpecifiedSchema = Some(schema), className = fileFormatClassName, options = sourceOptions.optionMapWithoutPath) Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 9aea34df6eb5c..9d174051bc923 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -136,7 +136,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo val dataSource = DataSource( sparkSession, - inputSchema = userSpecifiedSchema, + userSpecifiedSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap) Dataset.ofRows(sparkSession, StreamingRelation(dataSource)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 4b998ef389ac2..3d533c14e18e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1346,7 +1346,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val d1 = DataSource( spark, - inputSchema = None, + userSpecifiedSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, className = classOf[JsonFileFormat].getCanonicalName, @@ -1354,7 +1354,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val d2 = DataSource( spark, - inputSchema = None, + userSpecifiedSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, className = classOf[JsonFileFormat].getCanonicalName, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index b2e6ddbfa1862..8410a2e4a47ca 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -74,7 +74,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val dataSource = DataSource( sparkSession, - inputSchema = Some(table.schema), + userSpecifiedSchema = Some(table.schema), partitionColumns = table.partitionColumnNames, bucketSpec = table.bucketSpec, className = table.provider.get, @@ -278,7 +278,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log DataSource( sparkSession = sparkSession, paths = paths, - inputSchema = Some(metastoreRelation.schema), + userSpecifiedSchema = Some(metastoreRelation.schema), bucketSpec = bucketSpec, options = options, className = fileType).resolveRelation(), From 59d06f8801bd12f51fa07b1c22d478d4f1181b36 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 19 Sep 2016 23:05:17 -0700 Subject: [PATCH 6/7] address comments --- .../test/scala/org/apache/spark/sql/sources/InsertSuite.scala | 2 +- .../scala/org/apache/spark/sql/sources/TableScanSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 29902bbf89acc..5eb54643f204f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -70,7 +70,7 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { withTempView("t1", "t2") { sql( """ - |CREATE TEMPORARY TABLE t1 + |CREATE TEMPORARY VIEW t1 |USING org.apache.spark.sql.sources.SimpleScanSource |OPTIONS ( | From '1', diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 9d81a2ce5e4d5..ee40cddf0280d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -345,7 +345,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { (1 to 10).map(Row(_)).toSeq) } - test("create a temp table that does not have a path in the option") { + test("create a temp view or a persistent table that does not need a path in the option") { Seq("TEMPORARY VIEW", "TABLE").foreach { tableType => val tableName = "relationProvierWithSchema" withTable(tableName) { From 7a807384b5814b7458ff534ff6146b842f00a8d4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 
20 Sep 2016 15:04:18 -0700 Subject: [PATCH 7/7] address comments --- .../spark/sql/sources/TableScanSuite.scala | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index ee40cddf0280d..86bcb4d4b00c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -345,24 +345,6 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { (1 to 10).map(Row(_)).toSeq) } - test("create a temp view or a persistent table that does not need a path in the option") { - Seq("TEMPORARY VIEW", "TABLE").foreach { tableType => - val tableName = "relationProvierWithSchema" - withTable(tableName) { - sql( - s""" - |CREATE $tableType $tableName - |USING org.apache.spark.sql.sources.SimpleScanSource - |OPTIONS ( - | From '1', - | To '10' - |) - """.stripMargin) - checkAnswer(spark.table(tableName), spark.range(1, 11).toDF()) - } - } - } - test("exceptions") { // Make sure we do throw correct exception when users use a relation provider that // only implements the RelationProvider or the SchemaRelationProvider.
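
The net effect of the series is that a source implementing only RelationProvider can back a catalog table: when the table is later read, the schema coming from the catalog equals the schema the relation reports, so resolveRelation no longer rejects it. A minimal sketch of that path, mirroring the TableScanSuite test kept by this series (the table name is illustrative; assumes the SimpleScanSource test source is available and spark is a SparkSession):

    // No path and no user-specified schema in the DDL; the schema is taken from
    // the relation and recorded in the catalog.
    spark.sql(
      """
        |CREATE TABLE simpleSourceTable
        |USING org.apache.spark.sql.sources.SimpleScanSource
        |OPTIONS (From '1', To '10')
      """.stripMargin)

    // Reading back resolves the relation with the catalog schema, which matches
    // the schema SimpleScanSource reports, so the lookup succeeds.
    assert(spark.table("simpleSourceTable").count() == 10)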