From 305418c8d7b51de0b20fb55e7c8baf652ca06284 Mon Sep 17 00:00:00 2001 From: Zhan Zhang Date: Wed, 13 May 2015 09:21:26 -0700 Subject: [PATCH 01/12] orc data source support --- .../sql/hive/orc/HadoopTypeConverter.scala | 93 +++ .../spark/sql/hive/orc/OrcFileOperator.scala | 87 +++ .../spark/sql/hive/orc/OrcFilters.scala | 117 ++++ .../spark/sql/hive/orc/OrcRelation.scala | 176 ++++++ .../sql/hive/orc/OrcTableOperations.scala | 116 ++++ .../apache/spark/sql/hive/orc/package.scala | 61 ++ .../hive/orc/OrcPartitionDiscoverySuite.scala | 278 +++++++++ .../spark/sql/hive/orc/OrcQuerySuite.scala | 211 +++++++ .../spark/sql/hive/orc/OrcRelationTest.scala | 533 ++++++++++++++++++ .../apache/spark/sql/hive/orc/OrcSuite.scala | 211 +++++++ 10 files changed, 1883 insertions(+) create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationTest.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala new file mode 100644 index 0000000000000..aabc5477b05a6 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.hive.common.`type`.HiveVarchar +import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.spark.sql.catalyst.expressions.{Row, MutableRow} + +import scala.collection.JavaConversions._ + +/** + * We can consolidate TableReader.unwrappers and HiveInspectors.wrapperFor to use + * this class. 
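+ *
+ * Unwrappers are built once per field so that copying a value into a MutableRow does
+ * not pattern match on the ObjectInspector for every row; wrappers go the other way,
+ * converting Catalyst values into the Java objects Hive ObjectInspectors expect.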
+ * + */ +private[hive] object HadoopTypeConverter extends HiveInspectors { + /** + * Builds specific unwrappers ahead of time according to object inspector + * types to avoid pattern matching and branching costs per row. + */ + def unwrappers(fieldRefs: Seq[StructField]): Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { + _.getFieldObjectInspector match { + case oi: BooleanObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + case oi: ByteObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + case oi: ShortObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + case oi: IntObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + case oi: LongObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + case oi: FloatObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + case oi: DoubleObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi => + (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) + } + } + + /** + * Wraps with Hive types based on object inspector. + */ + def wrappers(oi: ObjectInspector): Any => Any = oi match { + case _: JavaHiveVarcharObjectInspector => + (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + + case _: JavaHiveDecimalObjectInspector => + (o: Any) => HiveShim.createDecimal(o.asInstanceOf[BigDecimal].underlying()) + + case soi: StandardStructObjectInspector => + val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) + (o: Any) => { + val struct = soi.create() + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row].toSeq).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + } + struct + } + + case loi: ListObjectInspector => + val wrapper = wrapperFor(loi.getListElementObjectInspector) + (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) + + case moi: MapObjectInspector => + val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) + val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) + (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => + keyWrapper(key) -> valueWrapper(value) + }) + + case _ => + identity[Any] + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala new file mode 100644 index 0000000000000..6805677e84ea9 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.IOException + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.hive.HiveMetastoreTypes +import org.apache.spark.sql.types.StructType + +private[orc] object OrcFileOperator extends Logging{ + + def getFileReader(pathStr: String, config: Option[Configuration] = None ): Reader = { + var conf = config.getOrElse(new Configuration) + val fspath = new Path(pathStr) + val fs = fspath.getFileSystem(conf) + val orcFiles = listOrcFiles(pathStr, conf) + OrcFile.createReader(fs, orcFiles(0)) + } + + def readSchema(path: String, conf: Option[Configuration]): StructType = { + val reader = getFileReader(path, conf) + val readerInspector: StructObjectInspector = reader.getObjectInspector + .asInstanceOf[StructObjectInspector] + val schema = readerInspector.getTypeName + HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] + } + + def getObjectInspector(path: String, conf: Option[Configuration]): StructObjectInspector = { + val reader = getFileReader(path, conf) + val readerInspector: StructObjectInspector = reader.getObjectInspector + .asInstanceOf[StructObjectInspector] + readerInspector + } + + def deletePath(pathStr: String, conf: Configuration): Unit = { + val fspath = new Path(pathStr) + val fs = fspath.getFileSystem(conf) + try { + fs.delete(fspath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${fspath.toString} prior" + + s" to InsertIntoOrcTable:\n${e.toString}") + } + } + + def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { + val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + val path = origPath.makeQualified(fs) + val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) + .filterNot(_.isDir) + .map(_.getPath) + .filterNot(_.getName.startsWith("_")) + .filterNot(_.getName.startsWith(".")) + + if (paths == null || paths.size == 0) { + throw new IllegalArgumentException( + s"orcFileOperator: path $path does not have valid orc files matching the pattern") + } + logInfo("Qualified file list: ") + paths.foreach{x=>logInfo(x.toString)} + paths + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala new file mode 100644 index 0000000000000..8e73f9181f08b --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder +import org.apache.spark.Logging +import org.apache.spark.sql.sources._ + +private[sql] object OrcFilters extends Logging { + + def createFilter(expr: Array[Filter]): Option[SearchArgument] = { + if (expr == null || expr.size == 0) return None + var sarg: Option[Builder] = Some(SearchArgument.FACTORY.newBuilder()) + sarg.get.startAnd() + expr.foreach { + x => { + sarg match { + case Some(s1) => sarg = createFilter(x, s1) + case _ => None + } + } + } + sarg match { + case Some(b) => Some(b.end.build) + case _ => None + } + } + + def createFilter(expression: Filter, builder: Builder): Option[Builder] = { + expression match { + case p@And(left: Filter, right: Filter) => { + val b1 = builder.startAnd() + val b2 = createFilter(left, b1) + b2 match { + case Some(b) => val b3 = createFilter(right, b) + if (b3.isDefined) { + Some(b3.get.end) + } else { + None + } + case _ => None + } + } + case p@Or(left: Filter, right: Filter) => { + val b1 = builder.startOr() + val b2 = createFilter(left, b1) + b2 match { + case Some(b) => val b3 = createFilter(right, b) + if (b3.isDefined) { + Some(b3.get.end) + } else { + None + } + case _ => None + } + } + case p@Not(child: Filter) => { + val b1 = builder.startNot() + val b2 = createFilter(child, b1) + b2 match { + case Some(b) => Some(b.end) + case _ => None + } + } + case p@EqualTo(attribute: String, value: Any) => { + val b1 = builder.equals(attribute, value) + Some(b1) + } + case p@LessThan(attribute: String, value: Any) => { + val b1 = builder.lessThan(attribute ,value) + Some(b1) + } + case p@LessThanOrEqual(attribute: String, value: Any) => { + val b1 = builder.lessThanEquals(attribute, value) + Some(b1) + } + case p@GreaterThan(attribute: String, value: Any) => { + val b1 = builder.startNot().lessThanEquals(attribute, value).end() + Some(b1) + } + case p@GreaterThanOrEqual(attribute: String, value: Any) => { + val b1 = builder.startNot().lessThan(attribute, value).end() + Some(b1) + } + case p@IsNull(attribute: String) => { + val b1 = builder.startNot().isNull(attribute).end() + Some(b1) + } + case p@In(attribute: String, values: Array[Any]) => { + val b1 = builder.in(attribute, values) + Some(b1) + } + // not supported in filter + // case p@EqualNullSafe(left: String, right: String) => { + // val b1 = builder.nullSafeEquals(left, right) + // Some(b1) + // } + case _ => None + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala new file mode 100644 index 0000000000000..816f3794a6a02 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.util.Objects + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.io.orc.{OrcSerde, OrcOutputFormat} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StructObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfoUtils, TypeInfo} +import org.apache.hadoop.io.{Writable, NullWritable} +import org.apache.hadoop.mapred.{RecordWriter, Reporter, JobConf} +import org.apache.hadoop.mapreduce.{TaskID, TaskAttemptContext} +import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.hive.HiveMetastoreTypes +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.sources._ +import scala.collection.JavaConversions._ + + +private[sql] class DefaultSource extends FSBasedRelationProvider { + + def createRelation( + sqlContext: SQLContext, + paths: Array[String], + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): FSBasedRelation ={ + val partitionSpec = partitionColumns.map(PartitionSpec(_, Seq.empty[Partition])) + OrcRelation(paths, parameters, + schema, partitionSpec)(sqlContext) + } +} + + +private[sql] class OrcOutputWriter extends OutputWriter with SparkHadoopMapRedUtil { + var recordWriter: RecordWriter[NullWritable, Writable] = _ + var taskAttemptContext: TaskAttemptContext = _ + var serializer: OrcSerde = _ + var wrappers: Array[Any => Any] = _ + var created = false + var path: String = _ + var dataSchema: StructType = _ + var fieldOIs: Array[ObjectInspector] = _ + var standardOI: StructObjectInspector = _ + + + override def init(path: String, + dataSchema: StructType, + context: TaskAttemptContext): Unit = { + this.path = path + this.dataSchema = dataSchema + taskAttemptContext = context + } + + // Avoid create empty file without schema attached + private def initWriter() = { + if (!created) { + created = true + val conf = taskAttemptContext.getConfiguration + val outputFormat = new OrcOutputFormat() + val taskId: TaskID = taskAttemptContext.getTaskAttemptID.getTaskID + val partition: Int = taskId.getId + val filename = s"part-r-${partition}-${System.currentTimeMillis}.orc" + val file = new Path(path, filename) + val fs = file.getFileSystem(conf) + val orcSchema = HiveMetastoreTypes.toMetastoreType(dataSchema) + + serializer = new OrcSerde + val typeInfo: TypeInfo = + TypeInfoUtils.getTypeInfoFromTypeString(orcSchema) + standardOI = TypeInfoUtils + .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) + .asInstanceOf[StructObjectInspector] + fieldOIs = standardOI + 
.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + wrappers = fieldOIs.map(HadoopTypeConverter.wrappers) + recordWriter = { + outputFormat.getRecordWriter(fs, + conf.asInstanceOf[JobConf], + file.toUri.getPath, Reporter.NULL) + .asInstanceOf[org.apache.hadoop.mapred.RecordWriter[NullWritable, Writable]] + } + } + } + override def write(row: Row): Unit = { + initWriter() + var i = 0 + val outputData = new Array[Any](fieldOIs.length) + while (i < row.length) { + outputData(i) = wrappers(i)(row(i)) + i += 1 + } + val writable = serializer.serialize(outputData, standardOI) + recordWriter.write(NullWritable.get(), writable) + } + + override def close(): Unit = { + if (recordWriter != null) { + recordWriter.close(Reporter.NULL) + } + } +} + + +@DeveloperApi +private[sql] case class OrcRelation(override val paths: Array[String], + parameters: Map[String, String], + maybeSchema: Option[StructType] = None, + maybePartitionSpec: Option[PartitionSpec] = None)( + @transient val sqlContext: SQLContext) + extends FSBasedRelation(paths, maybePartitionSpec) + with Logging { + self: Product => + @transient val conf = sqlContext.sparkContext.hadoopConfiguration + + + override def dataSchema: StructType = + maybeSchema.getOrElse(OrcFileOperator.readSchema(paths(0), Some(conf))) + + override def outputWriterClass: Class[_ <: OutputWriter] = classOf[OrcOutputWriter] + /** Attributes */ + var output: Seq[Attribute] = schema.toAttributes + + override def needConversion: Boolean = false + + // Equals must also take into account the output attributes so that we can distinguish between + // different instances of the same relation, + override def equals(other: Any): Boolean = other match { + case that: OrcRelation => + paths.toSet == that.paths.toSet && + dataSchema == that.dataSchema && + schema == that.schema && + partitionColumns == that.partitionColumns + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode( + paths.toSet, + dataSchema, + schema, + maybePartitionSpec) + } + override def buildScan(requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[String]): RDD[Row] = { + val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes + OrcTableScan(output, this, filters, inputPaths).execute() + } +} + +private[sql] object OrcRelation extends Logging { + // Default partition name to use when the partition column value is null or empty string. + val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala new file mode 100644 index 0000000000000..94c78a14524b5 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.util._ +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.io.orc._ +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.sql.sources.Filter +import org.apache.spark.{Logging, SerializableWritable} +import scala.collection.JavaConversions._ + +case class OrcTableScan(attributes: Seq[Attribute], + @transient relation: OrcRelation, + filters: Array[Filter], + inputPaths: Array[String]) extends Logging { + @transient val sqlContext = relation.sqlContext + val path = relation.paths(0) + + def addColumnIds(output: Seq[Attribute], + relation: OrcRelation, conf: Configuration) { + val ids = + output.map(a => + relation.dataSchema.toAttributes.indexWhere(_.name == a.name): Integer) + .filter(_ >= 0) + val names = attributes.map(_.name) + val sorted = ids.zip(names).sorted + HiveShim.appendReadColumns(conf, sorted.map(_._1), sorted.map(_._2)) + } + + def buildFilter(job: Job, filters: Array[Filter]): Unit = { + if (ORC_FILTER_PUSHDOWN_ENABLED) { + val conf: Configuration = job.getConfiguration + val recordFilter = OrcFilters.createFilter(filters) + if (recordFilter.isDefined) { + conf.set(SARG_PUSHDOWN, toKryo(recordFilter.get)) + conf.setBoolean(INDEX_FILTER, true) + } + } + } + + // Transform all given raw `Writable`s into `Row`s. 
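+  // Each Writable is deserialized with OrcSerde; field references are resolved by
+  // (lower-cased) attribute name against the file's StructObjectInspector, and the
+  // pre-built unwrappers from HadoopTypeConverter copy each field into the shared
+  // MutableRow, using setNullAt for null values.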
+ def fillObject(conf: Configuration, + iterator: Iterator[org.apache.hadoop.io.Writable], + nonPartitionKeyAttrs: Seq[(Attribute, Int)], + mutableRow: MutableRow): Iterator[Row] = { + val deserializer = new OrcSerde + val soi = OrcFileOperator.getObjectInspector(path, Some(conf)) + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { + case (attr, ordinal) => + soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal + }.unzip + val unwrappers = HadoopTypeConverter.unwrappers(fieldRefs) + // Map each tuple to a row object + iterator.map { value => + val raw = deserializer.deserialize(value) + logDebug("Raw data: " + raw) + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 + } + mutableRow: Row + } + } + + def execute(): RDD[Row] = { + val sc = sqlContext.sparkContext + val job = new Job(sc.hadoopConfiguration) + val conf: Configuration = job.getConfiguration + + buildFilter(job, filters) + addColumnIds(attributes, relation, conf) + FileInputFormat.setInputPaths(job, inputPaths.map(new Path(_)): _*) + + val inputClass = classOf[OrcInputFormat].asInstanceOf[ + Class[_ <: org.apache.hadoop.mapred.InputFormat[NullWritable, Writable]]] + + val rdd = sc.hadoopRDD(conf.asInstanceOf[JobConf], + inputClass, classOf[NullWritable], classOf[Writable]).map(_._2) + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + val wrappedConf = new SerializableWritable(conf) + val rowRdd: RDD[Row] = rdd.mapPartitions { iter => + fillObject(wrappedConf.value, iter, attributes.zipWithIndex, mutableRow) + } + rowRdd + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala new file mode 100644 index 0000000000000..a85f035a9424d --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import com.esotericsoftware.kryo.io.Output +import com.esotericsoftware.kryo.Kryo +import org.apache.commons.codec.binary.Base64 +import org.apache.spark.sql.{SaveMode, DataFrame} +import scala.reflect.runtime.universe.{TypeTag, typeTag} + +package object orc { + implicit class OrcContext(sqlContext: HiveContext) { + import sqlContext._ + @scala.annotation.varargs + def orcFile(paths: String*): DataFrame = { + if (paths.isEmpty) { + emptyDataFrame + } else { + val orcRelation = OrcRelation(paths.toArray, Map.empty)(sqlContext) + sqlContext.baseRelationToDataFrame(orcRelation) + } + } + } + + implicit class OrcSchemaRDD(dataFrame: DataFrame) { + def saveAsOrcFile(path: String, mode: SaveMode = SaveMode.Overwrite): Unit = { + dataFrame.save( + path, + source = classOf[DefaultSource].getCanonicalName, + mode = SaveMode.Overwrite) + } + } + + // Flags for orc copression, predicates pushdown, etc. + val orcDefaultCompressVar = "hive.exec.orc.default.compress" + var ORC_FILTER_PUSHDOWN_ENABLED = true + val SARG_PUSHDOWN = "sarg.pushdown"; + val INDEX_FILTER = "hive.optimize.index.filter" + + def toKryo(input: Any): String = { + val out = new Output(4 * 1024, 10 * 1024 * 1024); + new Kryo().writeObject(out, input); + out.close(); + Base64.encodeBase64String(out.toBytes()); + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala new file mode 100644 index 0000000000000..b8fe582498ecf --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.util.Utils +import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + + +// The data where the partitioning key exists only in the directory structure. 
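+// e.g. written under base/pi=1/ps=foo/ by makePartitionDir below, so pi and ps are
+// recovered from the directory path rather than stored in the ORC files themselves.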
+case class OrcParData(intField: Int, stringField: String) + +// The data that also includes the partitioning key +case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) + +class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { + val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" + + def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + def makeOrcFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().saveAsOrcFile(path.getCanonicalPath) + } + + + def makeOrcFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.saveAsOrcFile(path.getCanonicalPath) + } + + protected def withTempTable(tableName: String)(f: => Unit): Unit = { + try f finally TestHive.dropTempTable(tableName) + } + + protected def makePartitionDir( + basePath: File, + defaultPartitionName: String, + partitionCols: (String, Any)*): File = { + val partNames = partitionCols.map { case (k, v) => + val valueString = if (v == null || v == "") defaultPartitionName else v.toString + s"$k=$valueString" + } + + val partDir = partNames.foldLeft(basePath) { (parent, child) => + new File(parent, child) + } + + assert(partDir.mkdirs(), s"Couldn't create directory $partDir") + partDir + } + + test("read partitioned table - normal case") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeOrcFile( + (1 to 10).map(i => OrcParData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + TestHive.orcFile(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, 1, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, i.toString, pi, "foo")) + } + } + } + + test("read partitioned table - partition key included in orc file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeOrcFile( + (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + TestHive.orcFile(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, 1, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, "foo")) + } + } + } + + + test("read partitioned table - with nulls") { + withTempDir { base => + for { + // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero... 
+ pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeOrcFile( + (1 to 10).map(i => OrcParData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val orcRelation = load( + "org.apache.spark.sql.hive.orc.DefaultSource", + Map( + "path" -> base.getCanonicalPath, + OrcRelation.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + + orcRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi IS NULL"), + for { + i <- 1 to 10 + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, null, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + } yield Row(i, i.toString, pi, null)) + } + } + } + + test("read partitioned table - with nulls and partition keys are included in Orc file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeOrcFile( + (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val orcRelation = load( + "org.apache.spark.sql.hive.orc.DefaultSource", + Map( + "path" -> base.getCanonicalPath, + OrcRelation.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + + orcRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, null)) + } + } + } + + ignore("read partitioned table - merging compatible schemas: not supported yet") { + withTempDir { base => + makeOrcFile( + (1 to 10).map(i => Tuple1(i)).toDF("intField"), + makePartitionDir(base, defaultPartitionName, "pi" -> 1)) + + makeOrcFile( + (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), + makePartitionDir(base, defaultPartitionName, "pi" -> 2)) + + load(base.getCanonicalPath, "org.apache.spark.sql.hive.orc").registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + (1 to 10).map(i => Row(i, null, 1)) ++ (1 to 10).map(i => Row(i, i.toString, 2))) + } + } + } +} + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala new file mode 100644 index 0000000000000..90490df765ca6 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import org.apache.hadoop.hive.ql.io.orc.CompressionKind +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.util.Utils +import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} + +case class TestRDDEntry(key: Int, value: String) + +case class NullReflectData( + intField: java.lang.Integer, + longField: java.lang.Long, + floatField: java.lang.Float, + doubleField: java.lang.Double, + booleanField: java.lang.Boolean) + +case class OptionalReflectData( + intField: Option[Int], + longField: Option[Long], + floatField: Option[Float], + doubleField: Option[Double], + booleanField: Option[Boolean]) + +case class Nested(i: Int, s: String) + +case class Data(array: Seq[Int], nested: Nested) + +case class AllDataTypes( + stringField: String, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + shortField: Short, + byteField: Byte, + booleanField: Boolean) + +case class AllDataTypesWithNonPrimitiveType( + stringField: String, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + shortField: Short, + byteField: Byte, + booleanField: Boolean, + array: Seq[Int], + arrayContainsNull: Seq[Option[Int]], + map: Map[Int, Long], + mapValueContainsNull: Map[Int, Option[Long]], + data: Data) + +case class BinaryData(binaryData: Array[Byte]) + +case class Contact(name: String, phone: String) + +case class Person(name: String, age: Int, contacts: Seq[Contact]) + +class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { + + def getTempFilePath(prefix: String, suffix: String = ""): File = { + val tempFile = File.createTempFile(prefix, suffix) + tempFile.delete() + tempFile + } + + test("Read/Write All Types") { + val tempDir = getTempFilePath("orcTest").getCanonicalPath + val range = (0 to 255) + val data = sparkContext.parallelize(range) + .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)) + data.toDF().saveAsOrcFile(tempDir) + checkAnswer( + TestHive.orcFile(tempDir), + data.toDF().collect().toSeq) + Utils.deleteRecursively(new File(tempDir)) + } + + test("read/write binary data") { + val tempDir = getTempFilePath("orcTest").getCanonicalPath + sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).toDF().saveAsOrcFile(tempDir) + TestHive.orcFile(tempDir) + .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) + .collect().toSeq == Seq("test") + Utils.deleteRecursively(new File(tempDir)) + } + + test("Read/Write All Types with non-primitive type") { + val tempDir = getTempFilePath("orcTest").getCanonicalPath + val range = (0 to 255) + val data = sparkContext.parallelize(range) + .map(x => AllDataTypesWithNonPrimitiveType( + s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, + (0 until x), + (0 until x).map(Option(_).filter(_ % 3 == 0)), + (0 until 
x).map(i => i -> i.toLong).toMap, + (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), + Data((0 until x), Nested(x, s"$x")))) + data.toDF().saveAsOrcFile(tempDir) + + checkAnswer( + TestHive.orcFile(tempDir), + data.toDF().collect().toSeq) + Utils.deleteRecursively(new File(tempDir)) + } + + test("Creating case class RDD table") { + sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + .toDF().registerTempTable("tmp") + val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0)) + var counter = 1 + rdd.foreach { + // '===' does not like string comparison? + row: Row => { + assert(row.getString(1).equals(s"val_$counter"), s"row $counter value ${row.getString(1)} does not match val_$counter") + counter = counter + 1 + } + } + } + + test("Simple selection form orc table") { + val tempDir = getTempFilePath("orcTest").getCanonicalPath + val data = sparkContext.parallelize((1 to 10)) + .map(i => Person(s"name_$i", i, (0 until 2).map{ m=> + Contact(s"contact_$m", s"phone_$m") })) + data.toDF().saveAsOrcFile(tempDir) + val f = TestHive.orcFile(tempDir) + f.registerTempTable("tmp") + var rdd = sql("SELECT name FROM tmp where age <= 5") + assert(rdd.count() == 5) + + rdd = sql("SELECT name, contacts FROM tmp where age > 5") + assert(rdd.count() == 5) + val contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) + assert(contacts.count() == 10) + Utils.deleteRecursively(new File(tempDir)) + } + + test("save and load case class RDD with Nones as orc") { + val data = OptionalReflectData(None, None, None, None, None) + val rdd = sparkContext.parallelize(data :: Nil) + val tempDir = getTempFilePath("orcTest").getCanonicalPath + rdd.toDF().saveAsOrcFile(tempDir) + val readFile = TestHive.orcFile(tempDir) + val rdd_saved = readFile.collect() + assert(rdd_saved(0).toSeq === Seq.fill(5)(null)) + Utils.deleteRecursively(new File(tempDir)) + } + + // We only support zlib in hive0.12.0 now + test("Default Compression options for writing to an Orcfile") { + //TODO: support other compress codec + var tempDir = getTempFilePath("orcTest").getCanonicalPath + val rdd = sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + rdd.toDF().saveAsOrcFile(tempDir) + var actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression + assert(actualCodec == CompressionKind.ZLIB) + Utils.deleteRecursively(new File(tempDir)) + } + + // Following codec is supported in hive-0.13.1, ignore it now + ignore("Other Compression options for writing to an Orcfile only supported in hive 0.13.1 and above") { + TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "SNAPPY") + var tempDir = getTempFilePath("orcTest").getCanonicalPath + val rdd = sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + rdd.toDF().saveAsOrcFile(tempDir) + var actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression + assert(actualCodec == CompressionKind.SNAPPY) + Utils.deleteRecursively(new File(tempDir)) + + TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "NONE") + tempDir = getTempFilePath("orcTest").getCanonicalPath + rdd.toDF().saveAsOrcFile(tempDir) + actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression + assert(actualCodec == CompressionKind.NONE) + Utils.deleteRecursively(new File(tempDir)) + + TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "LZO") + tempDir = getTempFilePath("orcTest").getCanonicalPath + rdd.toDF().saveAsOrcFile(tempDir) + actualCodec = 
OrcFileOperator.getFileReader(tempDir).getCompression + assert(actualCodec == CompressionKind.LZO) + Utils.deleteRecursively(new File(tempDir)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationTest.scala new file mode 100644 index 0000000000000..8ed8c53e81613 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationTest.scala @@ -0,0 +1,533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.parquet.ParquetTest +import org.apache.spark.sql.sources.{FSBasedRelation, LogicalRelation} +import org.apache.spark.sql.types._ + +// TODO Don't extend ParquetTest +// This test suite extends ParquetTest for some convenient utility methods. These methods should be +// moved to some more general places, maybe QueryTest. 
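+// The suite exercises save()/load() and saveAsTable()/load() round trips through the
+// ORC FSBasedRelation for both non-partitioned and partitioned data, across all
+// SaveModes, plus Hadoop-style path globbing.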
+class OrcRelationTest extends QueryTest with ParquetTest { + override val sqlContext: SQLContext = TestHive + + import sqlContext._ + import sqlContext.implicits._ + + val dataSourceName = classOf[DefaultSource].getCanonicalName + + val dataSchema = + StructType( + Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType, nullable = false))) + + val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b") + + val partitionedTestDF1 = (for { + i <- 1 to 3 + p2 <- Seq("foo", "bar") + } yield (i, s"val_$i", 1, p2)).toDF("a", "b", "p1", "p2") + + val partitionedTestDF2 = (for { + i <- 1 to 3 + p2 <- Seq("foo", "bar") + } yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2") + + val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2) + + def checkQueries(df: DataFrame): Unit = { + // Selects everything + checkAnswer( + df, + for (i <- 1 to 3; p1 <- 1 to 2; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", p1, p2)) + + // Simple filtering and partition pruning + checkAnswer( + df.filter('a > 1 && 'p1 === 2), + for (i <- 2 to 3; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", 2, p2)) + + // Simple projection and filtering + checkAnswer( + df.filter('a > 1).select('b, 'a + 1), + for (i <- 2 to 3; _ <- 1 to 2; _ <- Seq("foo", "bar")) yield Row(s"val_$i", i + 1)) + + // Simple projection and partition pruning + checkAnswer( + df.filter('a > 1 && 'p1 < 2).select('b, 'p1), + for (i <- 2 to 3; _ <- Seq("foo", "bar")) yield Row(s"val_$i", 1)) + + // Self-join + df.registerTempTable("t") + withTempTable("t") { + checkAnswer( + sql( + """SELECT l.a, r.b, l.p1, r.p2 + |FROM t l JOIN t r + |ON l.a = r.a AND l.p1 = r.p1 AND l.p2 = r.p2 + """.stripMargin), + for (i <- 1 to 3; p1 <- 1 to 2; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", p1, p2)) + } + } + + test("save()/load() - non-partitioned table - Overwrite") { + withTempPath { file => + testDF.save( + path = file.getCanonicalPath, + source = dataSourceName, + mode = SaveMode.Overwrite) + + testDF.save( + path = file.getCanonicalPath, + source = dataSourceName, + mode = SaveMode.Overwrite) + + checkAnswer( + load( + source = dataSourceName, + options = Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchema.json)), + testDF.collect()) + } + } + + test("save()/load() - non-partitioned table - Append") { + withTempPath { file => + testDF.save( + path = file.getCanonicalPath, + source = dataSourceName, + mode = SaveMode.Overwrite) + + testDF.save( + path = file.getCanonicalPath, + source = dataSourceName, + mode = SaveMode.Append) + + checkAnswer( + load( + source = dataSourceName, + options = Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchema.json)).orderBy("a"), + testDF.unionAll(testDF).orderBy("a").collect()) + } + } + + test("save()/load() - non-partitioned table - ErrorIfExists") { + withTempDir { file => + intercept[RuntimeException] { + testDF.save( + path = file.getCanonicalPath, + source = dataSourceName, + mode = SaveMode.ErrorIfExists) + } + } + } + + test("save()/load() - non-partitioned table - Ignore") { + withTempDir { file => + testDF.save( + path = file.getCanonicalPath, + source = dataSourceName, + mode = SaveMode.Ignore) + + val path = new Path(file.getCanonicalPath) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + assert(fs.listStatus(path).isEmpty) + } + } + + test("save()/load() - partitioned table - simple queries") { + withTempPath { file => + partitionedTestDF.save( + source = dataSourceName, + mode = SaveMode.ErrorIfExists, 
+ options = Map("path" -> file.getCanonicalPath), + partitionColumns = Seq("p1", "p2")) + + checkQueries( + load( + source = dataSourceName, + options = Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchema.json))) + } + } + + test("save()/load() - partitioned table - Overwrite") { + withTempPath { file => + partitionedTestDF.save( + source = dataSourceName, + mode = SaveMode.Overwrite, + options = Map("path" -> file.getCanonicalPath), + partitionColumns = Seq("p1", "p2")) + + partitionedTestDF.save( + source = dataSourceName, + mode = SaveMode.Overwrite, + options = Map("path" -> file.getCanonicalPath), + partitionColumns = Seq("p1", "p2")) + + checkAnswer( + load( + source = dataSourceName, + options = Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchema.json)), + partitionedTestDF.collect()) + } + } + + test("save()/load() - partitioned table - Append") { + withTempPath { file => + partitionedTestDF.save( + source = dataSourceName, + mode = SaveMode.Overwrite, + options = Map("path" -> file.getCanonicalPath), + partitionColumns = Seq("p1", "p2")) + + partitionedTestDF.save( + source = dataSourceName, + mode = SaveMode.Append, + options = Map("path" -> file.getCanonicalPath), + partitionColumns = Seq("p1", "p2")) + + checkAnswer( + load( + source = dataSourceName, + options = Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchema.json)), + partitionedTestDF.unionAll(partitionedTestDF).collect()) + } + } + + test("save()/load() - partitioned table - Append - new partition values") { + withTempPath { file => + partitionedTestDF1.save( + source = dataSourceName, + mode = SaveMode.Overwrite, + options = Map("path" -> file.getCanonicalPath), + partitionColumns = Seq("p1", "p2")) + + partitionedTestDF2.save( + source = dataSourceName, + mode = SaveMode.Append, + options = Map("path" -> file.getCanonicalPath), + partitionColumns = Seq("p1", "p2")) + + checkAnswer( + load( + source = dataSourceName, + options = Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchema.json)), + partitionedTestDF.collect()) + } + } + + test("save()/load() - partitioned table - ErrorIfExists") { + withTempDir { file => + intercept[RuntimeException] { + partitionedTestDF.save( + source = dataSourceName, + mode = SaveMode.ErrorIfExists, + options = Map("path" -> file.getCanonicalPath), + partitionColumns = Seq("p1", "p2")) + } + } + } + + test("save()/load() - partitioned table - Ignore") { + withTempDir { file => + partitionedTestDF.save( + path = file.getCanonicalPath, + source = dataSourceName, + mode = SaveMode.Ignore) + + val path = new Path(file.getCanonicalPath) + val fs = path.getFileSystem(SparkHadoopUtil.get.conf) + assert(fs.listStatus(path).isEmpty) + } + } + + def withTable(tableName: String)(f: => Unit): Unit = { + try f finally sql(s"DROP TABLE $tableName") + } + + test("saveAsTable()/load() - non-partitioned table - Overwrite") { + testDF.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Overwrite, + Map("dataSchema" -> dataSchema.json)) + + withTable("t") { + checkAnswer(table("t"), testDF.collect()) + } + } + + test("saveAsTable()/load() - non-partitioned table - Append") { + testDF.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Overwrite) + + testDF.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Append) + + withTable("t") { + checkAnswer(table("t"), testDF.unionAll(testDF).orderBy("a").collect()) + } + } + + test("saveAsTable()/load() - non-partitioned 
table - ErrorIfExists") { + Seq.empty[(Int, String)].toDF().registerTempTable("t") + + withTempTable("t") { + intercept[AnalysisException] { + testDF.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.ErrorIfExists) + } + } + } + + test("saveAsTable()/load() - non-partitioned table - Ignore") { + Seq.empty[(Int, String)].toDF().registerTempTable("t") + + withTempTable("t") { + testDF.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Ignore) + + assert(table("t").collect().isEmpty) + } + } + + test("saveAsTable()/load() - partitioned table - simple queries") { + partitionedTestDF.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Overwrite, + Map("dataSchema" -> dataSchema.json)) + + withTable("t") { + checkQueries(table("t")) + } + } + + test("saveAsTable()/load() - partitioned table - Overwrite") { + partitionedTestDF.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Overwrite, + options = Map("dataSchema" -> dataSchema.json), + partitionColumns = Seq("p1", "p2")) + + partitionedTestDF.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Overwrite, + options = Map("dataSchema" -> dataSchema.json), + partitionColumns = Seq("p1", "p2")) + + withTable("t") { + checkAnswer(table("t"), partitionedTestDF.collect()) + } + } + + test("saveAsTable()/load() - partitioned table - Append") { + partitionedTestDF.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Overwrite, + options = Map("dataSchema" -> dataSchema.json), + partitionColumns = Seq("p1", "p2")) + + partitionedTestDF.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Append, + options = Map("dataSchema" -> dataSchema.json), + partitionColumns = Seq("p1", "p2")) + + withTable("t") { + checkAnswer(table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) + } + } + + test("saveAsTable()/load() - partitioned table - Append - new partition values") { + partitionedTestDF1.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Overwrite, + options = Map("dataSchema" -> dataSchema.json), + partitionColumns = Seq("p1", "p2")) + + partitionedTestDF2.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Append, + options = Map("dataSchema" -> dataSchema.json), + partitionColumns = Seq("p1", "p2")) + + withTable("t") { + checkAnswer(table("t"), partitionedTestDF.collect()) + } + } + + test("saveAsTable()/load() - partitioned table - Append - mismatched partition columns") { + partitionedTestDF1.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Overwrite, + options = Map("dataSchema" -> dataSchema.json), + partitionColumns = Seq("p1", "p2")) + + // Using only a subset of all partition columns + intercept[Throwable] { + partitionedTestDF2.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Append, + options = Map("dataSchema" -> dataSchema.json), + partitionColumns = Seq("p1")) + } + + // Using different order of partition columns + intercept[Throwable] { + partitionedTestDF2.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Append, + options = Map("dataSchema" -> dataSchema.json), + partitionColumns = Seq("p2", "p1")) + } + } + + test("saveAsTable()/load() - partitioned table - ErrorIfExists") { + Seq.empty[(Int, String)].toDF().registerTempTable("t") + + withTempTable("t") { + intercept[AnalysisException] { + 
partitionedTestDF.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.ErrorIfExists, + options = Map("dataSchema" -> dataSchema.json), + partitionColumns = Seq("p1", "p2")) + } + } + } + + test("saveAsTable()/load() - partitioned table - Ignore") { + Seq.empty[(Int, String)].toDF().registerTempTable("t") + + withTempTable("t") { + partitionedTestDF.saveAsTable( + tableName = "t", + source = dataSourceName, + mode = SaveMode.Ignore, + options = Map("dataSchema" -> dataSchema.json), + partitionColumns = Seq("p1", "p2")) + + assert(table("t").collect().isEmpty) + } + } + + test("Hadoop style globbing") { + withTempPath { file => + partitionedTestDF.save( + source = dataSourceName, + mode = SaveMode.Overwrite, + options = Map("path" -> file.getCanonicalPath), + partitionColumns = Seq("p1", "p2")) + + val df = load( + source = dataSourceName, + options = Map( + "path" -> s"${file.getCanonicalPath}/p1=*/p2=???", + "dataSchema" -> dataSchema.json)) + + val expectedPaths = Set( + s"${file.getCanonicalFile}/p1=1/p2=foo", + s"${file.getCanonicalFile}/p1=2/p2=foo", + s"${file.getCanonicalFile}/p1=1/p2=bar", + s"${file.getCanonicalFile}/p1=2/p2=bar" + ).map { p => + val path = new Path(p) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + path.makeQualified(fs.getUri, fs.getWorkingDirectory).toString + } + val actualPaths = df.queryExecution.analyzed.collectFirst { + case LogicalRelation(relation: FSBasedRelation) => + relation.paths.toSet + }.getOrElse { + fail("Expect an FSBasedRelation, but none could be found") + } + + assert(actualPaths === expectedPaths) + checkAnswer(df, partitionedTestDF.collect()) + } + } +} + +class FSBasedOrcRelationSuite extends OrcRelationTest { + override val dataSourceName: String = classOf[DefaultSource].getCanonicalName + + import sqlContext._ + import sqlContext.implicits._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) + .toDF("a", "b", "p1") + .saveAsOrcFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + load( + source = dataSourceName, + options = Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchemaWithPartition.json))) + } + } +} \ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSuite.scala new file mode 100644 index 0000000000000..8e0252e971105 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSuite.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.hive.test.TestHive._ + +case class OrcData(intField: Int, stringField: String) + +abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { + var orcTableDir: File = null + var orcTableAsDir: File = null + + override def beforeAll(): Unit = { + super.beforeAll() + + orcTableAsDir = File.createTempFile("orctests", "sparksql") + orcTableAsDir.delete() + orcTableAsDir.mkdir() + + // Hack: to prepare orc data files using hive external tables + orcTableDir = File.createTempFile("orctests", "sparksql") + orcTableDir.delete() + orcTableDir.mkdir() + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + (sparkContext + .makeRDD(1 to 10) + .map(i => OrcData(i, s"part-$i"))) + .toDF.registerTempTable(s"orc_temp_table") + + sql(s""" + create external table normal_orc + ( + intField INT, + stringField STRING + ) + STORED AS orc + location '${orcTableDir.getCanonicalPath}' + """) + + sql( + s"""insert into table normal_orc + select intField, stringField from orc_temp_table""") + + } + + override def afterAll(): Unit = { + orcTableDir.delete() + orcTableAsDir.delete() + } + + test("create temporary orc table") { + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + Row(1, "part-1") :: + Row(2, "part-2") :: + Row(3, "part-3") :: + Row(4, "part-4") :: + Row(5, "part-5") :: + Row(6, "part-6") :: + Row(7, "part-7") :: + Row(8, "part-8") :: + Row(9, "part-9") :: + Row(10, "part-10") :: Nil + ) + + checkAnswer( + sql("SELECT * FROM normal_orc_source where intField > 5"), + Row(6, "part-6") :: + Row(7, "part-7") :: + Row(8, "part-8") :: + Row(9, "part-9") :: + Row(10, "part-10") :: Nil + ) + + checkAnswer( + sql("SELECT count(intField), stringField FROM normal_orc_source group by stringField"), + Row(1, "part-1") :: + Row(1, "part-2") :: + Row(1, "part-3") :: + Row(1, "part-4") :: + Row(1, "part-5") :: + Row(1, "part-6") :: + Row(1, "part-7") :: + Row(1, "part-8") :: + Row(1, "part-9") :: + Row(1, "part-10") :: Nil + ) + + } + + test("create temporary orc table as") { + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10)) + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + Row(1, "part-1") :: + Row(2, "part-2") :: + Row(3, "part-3") :: + Row(4, "part-4") :: + Row(5, "part-5") :: + Row(6, "part-6") :: + Row(7, "part-7") :: + Row(8, "part-8") :: + Row(9, "part-9") :: + Row(10, "part-10") :: Nil + ) + + checkAnswer( + sql("SELECT * FROM normal_orc_source where intField > 5"), + Row(6, "part-6") :: + Row(7, "part-7") :: + Row(8, "part-8") :: + Row(9, "part-9") :: + Row(10, "part-10") :: Nil + ) + + checkAnswer( + sql("SELECT count(intField), stringField FROM normal_orc_source group by stringField"), + Row(1, "part-1") :: + Row(1, "part-2") :: + Row(1, "part-3") :: + Row(1, "part-4") :: + Row(1, "part-5") :: + Row(1, "part-6") :: + Row(1, "part-7") :: + Row(1, "part-8") :: + Row(1, "part-9") :: + 
Row(1, "part-10") :: Nil + ) + + } + + test("appending insert") { + sql("insert into table normal_orc_source select * from orc_temp_table where intField > 5") + checkAnswer( + sql("select * from normal_orc_source"), + Row(1, "part-1") :: + Row(2, "part-2") :: + Row(3, "part-3") :: + Row(4, "part-4") :: + Row(5, "part-5") :: + Row(6, "part-6") :: + Row(6, "part-6") :: + Row(7, "part-7") :: + Row(7, "part-7") :: + Row(8, "part-8") :: + Row(8, "part-8") :: + Row(9, "part-9") :: + Row(9, "part-9") :: + Row(10, "part-10") :: + Row(10, "part-10") :: Nil + ) + } + + test("overwrite insert") { + sql("insert overwrite table normal_orc_as_source select * from orc_temp_table where intField > 5") + checkAnswer( + sql("select * from normal_orc_as_source"), + Row(6, "part-6") :: + Row(7, "part-7") :: + Row(8, "part-8") :: + Row(9, "part-9") :: + Row(10, "part-10") :: Nil + ) + } +} + +class OrcSourceSuite extends OrcSuite { + override def beforeAll(): Unit = { + super.beforeAll() + + sql( s""" + create temporary table normal_orc_source + USING org.apache.spark.sql.hive.orc + OPTIONS ( + path '${new File(orcTableDir.getAbsolutePath).getCanonicalPath}' + ) + """) + + sql( s""" + create temporary table normal_orc_as_source + USING org.apache.spark.sql.hive.orc + OPTIONS ( + path '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' + ) + as select * from orc_temp_table + """) + } +} From 4e61c168bed49003a95388db2bb0b047e9da67da Mon Sep 17 00:00:00 2001 From: Zhan Zhang Date: Wed, 13 May 2015 12:17:02 -0700 Subject: [PATCH 02/12] minor change --- .../apache/spark/sql/hive/orc/OrcQuerySuite.scala | 13 ++++++++----- ...OrcRelationTest.scala => OrcRelationSuite.scala} | 2 +- .../orc/{OrcSuite.scala => OrcSourceSuite.scala} | 3 ++- 3 files changed, 11 insertions(+), 7 deletions(-) rename sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/{OrcRelationTest.scala => OrcRelationSuite.scala} (99%) rename sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/{OrcSuite.scala => OrcSourceSuite.scala} (98%) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 90490df765ca6..3c596d0654324 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -91,7 +91,8 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { val tempDir = getTempFilePath("orcTest").getCanonicalPath val range = (0 to 255) val data = sparkContext.parallelize(range) - .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)) + .map(x => + AllDataTypes(s"$x", x, x.toLong, x.toFloat,x.toDouble, x.toShort, x.toByte, x % 2 == 0)) data.toDF().saveAsOrcFile(tempDir) checkAnswer( TestHive.orcFile(tempDir), @@ -101,7 +102,8 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { test("read/write binary data") { val tempDir = getTempFilePath("orcTest").getCanonicalPath - sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).toDF().saveAsOrcFile(tempDir) + sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil) + .toDF().saveAsOrcFile(tempDir) TestHive.orcFile(tempDir) .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) .collect().toSeq == Seq("test") @@ -136,7 +138,8 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { rdd.foreach { // '===' does not like string 
comparison? row: Row => { - assert(row.getString(1).equals(s"val_$counter"), s"row $counter value ${row.getString(1)} does not match val_$counter") + assert(row.getString(1).equals(s"val_$counter"), + s"row $counter value ${row.getString(1)} does not match val_$counter") counter = counter + 1 } } @@ -173,7 +176,7 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { // We only support zlib in hive0.12.0 now test("Default Compression options for writing to an Orcfile") { - //TODO: support other compress codec + // TODO: support other compress codec var tempDir = getTempFilePath("orcTest").getCanonicalPath val rdd = sparkContext.parallelize((1 to 100)) .map(i => TestRDDEntry(i, s"val_$i")) @@ -184,7 +187,7 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { } // Following codec is supported in hive-0.13.1, ignore it now - ignore("Other Compression options for writing to an Orcfile only supported in hive 0.13.1 and above") { + ignore("Other Compression options for writing to an Orcfile - 0.13.1 and above") { TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "SNAPPY") var tempDir = getTempFilePath("orcTest").getCanonicalPath val rdd = sparkContext.parallelize((1 to 100)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala similarity index 99% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationTest.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala index 8ed8c53e81613..0b486068d4b0b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala @@ -530,4 +530,4 @@ class FSBasedOrcRelationSuite extends OrcRelationTest { "dataSchema" -> dataSchemaWithPartition.json))) } } -} \ No newline at end of file +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala similarity index 98% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 8e0252e971105..f86750bcfb6d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -175,7 +175,8 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { } test("overwrite insert") { - sql("insert overwrite table normal_orc_as_source select * from orc_temp_table where intField > 5") + sql("insert overwrite table normal_orc_as_source select * " + + "from orc_temp_table where intField > 5") checkAnswer( sql("select * from normal_orc_as_source"), Row(6, "part-6") :: From 7cc2c6431ee147c514d38c699be0af509a27095c Mon Sep 17 00:00:00 2001 From: Zhan Zhang Date: Wed, 13 May 2015 16:12:49 -0700 Subject: [PATCH 03/12] predicate fix --- .../main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 8e73f9181f08b..0a9924b139a48 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -99,6 +99,10 @@ private[sql] object OrcFilters extends Logging { Some(b1) } case p@IsNull(attribute: String) => { + val b1 = builder.isNull(attribute) + Some(b1) + } + case p@IsNotNull(attribute: String) => { val b1 = builder.startNot().isNull(attribute).end() Some(b1) } From f95abfdf5eaaa31254fd48b9830c12a5a9b0addb Mon Sep 17 00:00:00 2001 From: Zhan Zhang Date: Wed, 13 May 2015 17:01:06 -0700 Subject: [PATCH 04/12] reuse test suite --- .../hive/orc/OrcPartitionDiscoverySuite.scala | 20 - .../spark/sql/hive/orc/OrcRelationSuite.scala | 479 +----------------- 2 files changed, 2 insertions(+), 497 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index b8fe582498ecf..7ebf3c6eced26 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -254,25 +254,5 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before } } } - - ignore("read partitioned table - merging compatible schemas: not supported yet") { - withTempDir { base => - makeOrcFile( - (1 to 10).map(i => Tuple1(i)).toDF("intField"), - makePartitionDir(base, defaultPartitionName, "pi" -> 1)) - - makeOrcFile( - (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), - makePartitionDir(base, defaultPartitionName, "pi" -> 2)) - - load(base.getCanonicalPath, "org.apache.spark.sql.hive.orc").registerTempTable("t") - - withTempTable("t") { - checkAnswer( - sql("SELECT * FROM t"), - (1 to 10).map(i => Row(i, null, 1)) ++ (1 to 10).map(i => Row(i, i.toString, 2))) - } - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala index 0b486068d4b0b..1d8c421b90678 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala @@ -20,486 +20,11 @@ package org.apache.spark.sql.hive.orc import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetTest -import org.apache.spark.sql.sources.{FSBasedRelation, LogicalRelation} +import org.apache.spark.sql.sources.{FSBasedRelationTest} import org.apache.spark.sql.types._ -// TODO Don't extend ParquetTest -// This test suite extends ParquetTest for some convenient utility methods. These methods should be -// moved to some more general places, maybe QueryTest. 
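Note on the OrcFilters change above: `IsNull` had been built with `startNot().isNull(...).end()`, which is the SearchArgument form of IS NOT NULL; the patch maps `IsNull` to `builder.isNull` and adds a dedicated `IsNotNull` case. A minimal sketch of the corrected mapping, illustration only, using the same builder calls as createFilter rather than the full implementation:

import org.apache.hadoop.hive.ql.io.sarg.SearchArgument
import org.apache.spark.sql.sources.{Filter, IsNull, IsNotNull}

// Illustrative mapping of the two null predicates onto the ORC SearchArgument builder.
def nullPredicateSarg(filter: Filter): Option[SearchArgument] = {
  val builder = SearchArgument.FACTORY.newBuilder().startAnd()
  val withPredicate = filter match {
    case IsNull(attr)    => Some(builder.isNull(attr))                  // attr IS NULL
    case IsNotNull(attr) => Some(builder.startNot().isNull(attr).end()) // attr IS NOT NULL
    case _               => None                                        // unsupported predicate
  }
  withPredicate.map(_.end().build())
}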
-class OrcRelationTest extends QueryTest with ParquetTest { - override val sqlContext: SQLContext = TestHive - import sqlContext._ - import sqlContext.implicits._ - - val dataSourceName = classOf[DefaultSource].getCanonicalName - - val dataSchema = - StructType( - Seq( - StructField("a", IntegerType, nullable = false), - StructField("b", StringType, nullable = false))) - - val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b") - - val partitionedTestDF1 = (for { - i <- 1 to 3 - p2 <- Seq("foo", "bar") - } yield (i, s"val_$i", 1, p2)).toDF("a", "b", "p1", "p2") - - val partitionedTestDF2 = (for { - i <- 1 to 3 - p2 <- Seq("foo", "bar") - } yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2") - - val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2) - - def checkQueries(df: DataFrame): Unit = { - // Selects everything - checkAnswer( - df, - for (i <- 1 to 3; p1 <- 1 to 2; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", p1, p2)) - - // Simple filtering and partition pruning - checkAnswer( - df.filter('a > 1 && 'p1 === 2), - for (i <- 2 to 3; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", 2, p2)) - - // Simple projection and filtering - checkAnswer( - df.filter('a > 1).select('b, 'a + 1), - for (i <- 2 to 3; _ <- 1 to 2; _ <- Seq("foo", "bar")) yield Row(s"val_$i", i + 1)) - - // Simple projection and partition pruning - checkAnswer( - df.filter('a > 1 && 'p1 < 2).select('b, 'p1), - for (i <- 2 to 3; _ <- Seq("foo", "bar")) yield Row(s"val_$i", 1)) - - // Self-join - df.registerTempTable("t") - withTempTable("t") { - checkAnswer( - sql( - """SELECT l.a, r.b, l.p1, r.p2 - |FROM t l JOIN t r - |ON l.a = r.a AND l.p1 = r.p1 AND l.p2 = r.p2 - """.stripMargin), - for (i <- 1 to 3; p1 <- 1 to 2; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", p1, p2)) - } - } - - test("save()/load() - non-partitioned table - Overwrite") { - withTempPath { file => - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Overwrite) - - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Overwrite) - - checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), - testDF.collect()) - } - } - - test("save()/load() - non-partitioned table - Append") { - withTempPath { file => - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Overwrite) - - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Append) - - checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)).orderBy("a"), - testDF.unionAll(testDF).orderBy("a").collect()) - } - } - - test("save()/load() - non-partitioned table - ErrorIfExists") { - withTempDir { file => - intercept[RuntimeException] { - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.ErrorIfExists) - } - } - } - - test("save()/load() - non-partitioned table - Ignore") { - withTempDir { file => - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Ignore) - - val path = new Path(file.getCanonicalPath) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - assert(fs.listStatus(path).isEmpty) - } - } - - test("save()/load() - partitioned table - simple queries") { - withTempPath { file => - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.ErrorIfExists, - 
options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - checkQueries( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json))) - } - } - - test("save()/load() - partitioned table - Overwrite") { - withTempPath { file => - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), - partitionedTestDF.collect()) - } - } - - test("save()/load() - partitioned table - Append") { - withTempPath { file => - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Append, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), - partitionedTestDF.unionAll(partitionedTestDF).collect()) - } - } - - test("save()/load() - partitioned table - Append - new partition values") { - withTempPath { file => - partitionedTestDF1.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF2.save( - source = dataSourceName, - mode = SaveMode.Append, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), - partitionedTestDF.collect()) - } - } - - test("save()/load() - partitioned table - ErrorIfExists") { - withTempDir { file => - intercept[RuntimeException] { - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.ErrorIfExists, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - } - } - } - - test("save()/load() - partitioned table - Ignore") { - withTempDir { file => - partitionedTestDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Ignore) - - val path = new Path(file.getCanonicalPath) - val fs = path.getFileSystem(SparkHadoopUtil.get.conf) - assert(fs.listStatus(path).isEmpty) - } - } - - def withTable(tableName: String)(f: => Unit): Unit = { - try f finally sql(s"DROP TABLE $tableName") - } - - test("saveAsTable()/load() - non-partitioned table - Overwrite") { - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - Map("dataSchema" -> dataSchema.json)) - - withTable("t") { - checkAnswer(table("t"), testDF.collect()) - } - } - - test("saveAsTable()/load() - non-partitioned table - Append") { - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite) - - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append) - - withTable("t") { - checkAnswer(table("t"), testDF.unionAll(testDF).orderBy("a").collect()) - } - } - - test("saveAsTable()/load() - non-partitioned table 
- ErrorIfExists") { - Seq.empty[(Int, String)].toDF().registerTempTable("t") - - withTempTable("t") { - intercept[AnalysisException] { - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.ErrorIfExists) - } - } - } - - test("saveAsTable()/load() - non-partitioned table - Ignore") { - Seq.empty[(Int, String)].toDF().registerTempTable("t") - - withTempTable("t") { - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Ignore) - - assert(table("t").collect().isEmpty) - } - } - - test("saveAsTable()/load() - partitioned table - simple queries") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - Map("dataSchema" -> dataSchema.json)) - - withTable("t") { - checkQueries(table("t")) - } - } - - test("saveAsTable()/load() - partitioned table - Overwrite") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - withTable("t") { - checkAnswer(table("t"), partitionedTestDF.collect()) - } - } - - test("saveAsTable()/load() - partitioned table - Append") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - withTable("t") { - checkAnswer(table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) - } - } - - test("saveAsTable()/load() - partitioned table - Append - new partition values") { - partitionedTestDF1.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF2.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - withTable("t") { - checkAnswer(table("t"), partitionedTestDF.collect()) - } - } - - test("saveAsTable()/load() - partitioned table - Append - mismatched partition columns") { - partitionedTestDF1.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - // Using only a subset of all partition columns - intercept[Throwable] { - partitionedTestDF2.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1")) - } - - // Using different order of partition columns - intercept[Throwable] { - partitionedTestDF2.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p2", "p1")) - } - } - - test("saveAsTable()/load() - partitioned table - ErrorIfExists") { - Seq.empty[(Int, String)].toDF().registerTempTable("t") - - withTempTable("t") { - intercept[AnalysisException] { - partitionedTestDF.saveAsTable( - 
tableName = "t", - source = dataSourceName, - mode = SaveMode.ErrorIfExists, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - } - } - } - - test("saveAsTable()/load() - partitioned table - Ignore") { - Seq.empty[(Int, String)].toDF().registerTempTable("t") - - withTempTable("t") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Ignore, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - assert(table("t").collect().isEmpty) - } - } - - test("Hadoop style globbing") { - withTempPath { file => - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - val df = load( - source = dataSourceName, - options = Map( - "path" -> s"${file.getCanonicalPath}/p1=*/p2=???", - "dataSchema" -> dataSchema.json)) - - val expectedPaths = Set( - s"${file.getCanonicalFile}/p1=1/p2=foo", - s"${file.getCanonicalFile}/p1=2/p2=foo", - s"${file.getCanonicalFile}/p1=1/p2=bar", - s"${file.getCanonicalFile}/p1=2/p2=bar" - ).map { p => - val path = new Path(p) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - path.makeQualified(fs.getUri, fs.getWorkingDirectory).toString - } - val actualPaths = df.queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: FSBasedRelation) => - relation.paths.toSet - }.getOrElse { - fail("Expect an FSBasedRelation, but none could be found") - } - - assert(actualPaths === expectedPaths) - checkAnswer(df, partitionedTestDF.collect()) - } - } -} - -class FSBasedOrcRelationSuite extends OrcRelationTest { +class FSBasedOrcRelationSuite extends FSBasedRelationTest { override val dataSourceName: String = classOf[DefaultSource].getCanonicalName import sqlContext._ From 7b3c7c5233d370cd499ab4aa15106b0e393b23ff Mon Sep 17 00:00:00 2001 From: Zhan Zhang Date: Wed, 13 May 2015 19:48:02 -0700 Subject: [PATCH 05/12] save mode fix --- .../src/main/scala/org/apache/spark/sql/hive/orc/package.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala index a85f035a9424d..93af5af196dab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala @@ -42,7 +42,7 @@ package object orc { dataFrame.save( path, source = classOf[DefaultSource].getCanonicalName, - mode = SaveMode.Overwrite) + mode) } } From 3c9038e49a5370e861f63cbc6a40370a3d21163a Mon Sep 17 00:00:00 2001 From: Zhan Zhang Date: Thu, 14 May 2015 13:19:30 -0700 Subject: [PATCH 06/12] resolve review comments --- .../sql/hive/orc/HadoopTypeConverter.scala | 39 +---- .../spark/sql/hive/orc/OrcFileOperator.scala | 22 +-- .../spark/sql/hive/orc/OrcFilters.scala | 133 ++++++------------ .../spark/sql/hive/orc/OrcRelation.scala | 94 +++++-------- .../sql/hive/orc/OrcTableOperations.scala | 51 +++---- .../apache/spark/sql/hive/orc/package.scala | 19 +-- .../hive/orc/OrcPartitionDiscoverySuite.scala | 7 +- .../spark/sql/hive/orc/OrcQuerySuite.scala | 29 +++- 8 files changed, 157 insertions(+), 237 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala index aabc5477b05a6..713c076aee457 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.hive.orc -import org.apache.hadoop.hive.common.`type`.HiveVarchar -import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} + import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.primitive._ -import org.apache.spark.sql.catalyst.expressions.{Row, MutableRow} -import scala.collection.JavaConversions._ +import org.apache.spark.sql.catalyst.expressions.MutableRow +import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} /** * We can consolidate TableReader.unwrappers and HiveInspectors.wrapperFor to use @@ -59,35 +58,5 @@ private[hive] object HadoopTypeConverter extends HiveInspectors { /** * Wraps with Hive types based on object inspector. */ - def wrappers(oi: ObjectInspector): Any => Any = oi match { - case _: JavaHiveVarcharObjectInspector => - (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) - - case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveShim.createDecimal(o.asInstanceOf[BigDecimal].underlying()) - - case soi: StandardStructObjectInspector => - val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) - (o: Any) => { - val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row].toSeq).zipped.foreach { - (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) - } - struct - } - - case loi: ListObjectInspector => - val wrapper = wrapperFor(loi.getListElementObjectInspector) - (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) - - case moi: MapObjectInspector => - val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) - val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) - (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => - keyWrapper(key) -> valueWrapper(value) - }) - - case _ => - identity[Any] - } + def wrappers(oi: ObjectInspector): Any => Any = wrapperFor(oi) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 6805677e84ea9..4dd2d8951b728 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.hive.orc -import java.io.IOException - import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector + import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.hive.HiveMetastoreTypes @@ -31,7 +30,7 @@ import org.apache.spark.sql.types.StructType private[orc] object OrcFileOperator extends Logging{ def getFileReader(pathStr: String, config: Option[Configuration] = None ): Reader = { - var conf = config.getOrElse(new Configuration) + val conf = config.getOrElse(new Configuration) val fspath = new Path(pathStr) val fs = fspath.getFileSystem(conf) val orcFiles = listOrcFiles(pathStr, conf) @@ -53,19 +52,6 @@ private[orc] object OrcFileOperator extends Logging{ readerInspector } - def deletePath(pathStr: String, conf: 
Configuration): Unit = { - val fspath = new Path(pathStr) - val fs = fspath.getFileSystem(conf) - try { - fs.delete(fspath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${fspath.toString} prior" - + s" to InsertIntoOrcTable:\n${e.toString}") - } - } - def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) val fs = origPath.getFileSystem(conf) @@ -80,8 +66,6 @@ private[orc] object OrcFileOperator extends Logging{ throw new IllegalArgumentException( s"orcFileOperator: path $path does not have valid orc files matching the pattern") } - logInfo("Qualified file list: ") - paths.foreach{x=>logInfo(x.toString)} paths } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 0a9924b139a48..eda1cffe49810 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -22,100 +22,55 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder import org.apache.spark.Logging import org.apache.spark.sql.sources._ -private[sql] object OrcFilters extends Logging { - +/** + * It may be optimized by push down partial filters. But we are conservative here. + * Because if some filters fail to be parsed, the tree may be corrupted, + * and cannot be used anymore. + */ +private[orc] object OrcFilters extends Logging { def createFilter(expr: Array[Filter]): Option[SearchArgument] = { - if (expr == null || expr.size == 0) return None - var sarg: Option[Builder] = Some(SearchArgument.FACTORY.newBuilder()) - sarg.get.startAnd() - expr.foreach { - x => { - sarg match { - case Some(s1) => sarg = createFilter(x, s1) - case _ => None - } - } - } - sarg match { - case Some(b) => Some(b.end.build) - case _ => None + if (expr.nonEmpty) { + expr.foldLeft(Some(SearchArgument.FACTORY.newBuilder().startAnd()): Option[Builder]) { + (maybeBuilder, e) => createFilter(e, maybeBuilder) + }.map(_.end().build()) + } else { + None } } - def createFilter(expression: Filter, builder: Builder): Option[Builder] = { - expression match { - case p@And(left: Filter, right: Filter) => { - val b1 = builder.startAnd() - val b2 = createFilter(left, b1) - b2 match { - case Some(b) => val b3 = createFilter(right, b) - if (b3.isDefined) { - Some(b3.get.end) - } else { - None - } - case _ => None - } - } - case p@Or(left: Filter, right: Filter) => { - val b1 = builder.startOr() - val b2 = createFilter(left, b1) - b2 match { - case Some(b) => val b3 = createFilter(right, b) - if (b3.isDefined) { - Some(b3.get.end) - } else { - None - } - case _ => None - } - } - case p@Not(child: Filter) => { - val b1 = builder.startNot() - val b2 = createFilter(child, b1) - b2 match { - case Some(b) => Some(b.end) - case _ => None - } - } - case p@EqualTo(attribute: String, value: Any) => { - val b1 = builder.equals(attribute, value) - Some(b1) - } - case p@LessThan(attribute: String, value: Any) => { - val b1 = builder.lessThan(attribute ,value) - Some(b1) - } - case p@LessThanOrEqual(attribute: String, value: Any) => { - val b1 = builder.lessThanEquals(attribute, value) - Some(b1) - } - case p@GreaterThan(attribute: String, value: Any) => { - val b1 = builder.startNot().lessThanEquals(attribute, value).end() - Some(b1) - } - case p@GreaterThanOrEqual(attribute: String, value: Any) => { - val b1 = builder.startNot().lessThan(attribute, 
value).end() - Some(b1) - } - case p@IsNull(attribute: String) => { - val b1 = builder.isNull(attribute) - Some(b1) - } - case p@IsNotNull(attribute: String) => { - val b1 = builder.startNot().isNull(attribute).end() - Some(b1) - } - case p@In(attribute: String, values: Array[Any]) => { - val b1 = builder.in(attribute, values) - Some(b1) + private def createFilter(expression: Filter, maybeBuilder: Option[Builder]): Option[Builder] = { + maybeBuilder.flatMap { builder => + expression match { + case p@And(left, right) => + for { + lhs <- createFilter(left, Some(builder.startAnd())) + rhs <- createFilter(right, Some(lhs)) + } yield rhs.end() + case p@Or(left, right) => + for { + lhs <- createFilter(left, Some(builder.startOr())) + rhs <- createFilter(right, Some(lhs)) + } yield rhs.end() + case p@Not(child) => + createFilter(child, Some(builder.startNot())).map(_.end()) + case p@EqualTo(attribute, value) => + Some(builder.equals(attribute, value)) + case p@LessThan(attribute, value) => + Some(builder.lessThan(attribute, value)) + case p@LessThanOrEqual(attribute, value) => + Some(builder.lessThanEquals(attribute, value)) + case p@GreaterThan(attribute, value) => + Some(builder.startNot().lessThanEquals(attribute, value).end()) + case p@GreaterThanOrEqual(attribute, value) => + Some(builder.startNot().lessThan(attribute, value).end()) + case p@IsNull(attribute) => + Some(builder.isNull(attribute)) + case p@IsNotNull(attribute) => + Some(builder.startNot().isNull(attribute).end()) + case p@In(attribute, values) => + Some(builder.in(attribute, values)) + case _ => None } - // not supported in filter - // case p@EqualNullSafe(left: String, right: String) => { - // val b1 = builder.nullSafeEquals(left, right) - // Some(b1) - // } - case _ => None } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 816f3794a6a02..c68a58647cad7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -26,11 +26,11 @@ import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfoUtils, TypeInfo} import org.apache.hadoop.io.{Writable, NullWritable} import org.apache.hadoop.mapred.{RecordWriter, Reporter, JobConf} import org.apache.hadoop.mapreduce.{TaskID, TaskAttemptContext} + import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.hive.HiveMetastoreTypes import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} @@ -39,7 +39,6 @@ import scala.collection.JavaConversions._ private[sql] class DefaultSource extends FSBasedRelationProvider { - def createRelation( sqlContext: SQLContext, paths: Array[String], @@ -54,7 +53,7 @@ private[sql] class DefaultSource extends FSBasedRelationProvider { private[sql] class OrcOutputWriter extends OutputWriter with SparkHadoopMapRedUtil { - var recordWriter: RecordWriter[NullWritable, Writable] = _ + var taskAttemptContext: TaskAttemptContext = _ var serializer: OrcSerde = _ var wrappers: Array[Any => Any] = _ @@ -62,90 +61,75 @@ private[sql] class OrcOutputWriter extends OutputWriter with SparkHadoopMapRedUt var path: String = _ var dataSchema: StructType = _ var fieldOIs: Array[ObjectInspector] = _ - var standardOI: StructObjectInspector = _ 
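Note on the createFilter rewrite earlier in this patch: it threads an Option[Builder] through foldLeft and for-comprehensions, so a single predicate that cannot be expressed as a SearchArgument discards the whole filter instead of leaving a partially built tree behind (the conservative behaviour its scaladoc calls out). A toy sketch of the same fold-with-failure pattern, independent of the ORC builder API:

// Each step may fail; a single failure poisons the entire accumulated result.
def convertAll[A, B](items: Seq[A], start: B)(step: (B, A) => Option[B]): Option[B] =
  items.foldLeft(Option(start))((acc, item) => acc.flatMap(step(_, item)))

// Example: summing parsed integers, where one bad element yields None overall.
val parseStep = (sum: Int, s: String) => scala.util.Try(sum + s.toInt).toOption
convertAll(Seq("1", "2", "3"), 0)(parseStep)  // Some(6)
convertAll(Seq("1", "x", "3"), 0)(parseStep)  // None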
- + var structOI: StructObjectInspector = _ + var outputData: Array[Any] = _ + lazy val recordWriter: RecordWriter[NullWritable, Writable] = { + created = true + val conf = taskAttemptContext.getConfiguration + val taskId: TaskID = taskAttemptContext.getTaskAttemptID.getTaskID + val partition: Int = taskId.getId + val filename = f"part-r-$partition%05d-${System.currentTimeMillis}%015d.orc" + val file = new Path(path, filename) + val fs = file.getFileSystem(conf) + val outputFormat = new OrcOutputFormat() + outputFormat.getRecordWriter(fs, + conf.asInstanceOf[JobConf], + file.toUri.getPath, Reporter.NULL) + .asInstanceOf[org.apache.hadoop.mapred.RecordWriter[NullWritable, Writable]] + } override def init(path: String, dataSchema: StructType, context: TaskAttemptContext): Unit = { this.path = path - this.dataSchema = dataSchema taskAttemptContext = context + val orcSchema = HiveMetastoreTypes.toMetastoreType(dataSchema) + serializer = new OrcSerde + val typeInfo: TypeInfo = + TypeInfoUtils.getTypeInfoFromTypeString(orcSchema) + structOI = TypeInfoUtils + .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) + .asInstanceOf[StructObjectInspector] + fieldOIs = structOI + .getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + outputData = new Array[Any](fieldOIs.length) + wrappers = fieldOIs.map(HadoopTypeConverter.wrappers) } - // Avoid create empty file without schema attached - private def initWriter() = { - if (!created) { - created = true - val conf = taskAttemptContext.getConfiguration - val outputFormat = new OrcOutputFormat() - val taskId: TaskID = taskAttemptContext.getTaskAttemptID.getTaskID - val partition: Int = taskId.getId - val filename = s"part-r-${partition}-${System.currentTimeMillis}.orc" - val file = new Path(path, filename) - val fs = file.getFileSystem(conf) - val orcSchema = HiveMetastoreTypes.toMetastoreType(dataSchema) - - serializer = new OrcSerde - val typeInfo: TypeInfo = - TypeInfoUtils.getTypeInfoFromTypeString(orcSchema) - standardOI = TypeInfoUtils - .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) - .asInstanceOf[StructObjectInspector] - fieldOIs = standardOI - .getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray - wrappers = fieldOIs.map(HadoopTypeConverter.wrappers) - recordWriter = { - outputFormat.getRecordWriter(fs, - conf.asInstanceOf[JobConf], - file.toUri.getPath, Reporter.NULL) - .asInstanceOf[org.apache.hadoop.mapred.RecordWriter[NullWritable, Writable]] - } - } - } override def write(row: Row): Unit = { - initWriter() var i = 0 - val outputData = new Array[Any](fieldOIs.length) while (i < row.length) { outputData(i) = wrappers(i)(row(i)) i += 1 } - val writable = serializer.serialize(outputData, standardOI) + val writable = serializer.serialize(outputData, structOI) recordWriter.write(NullWritable.get(), writable) } override def close(): Unit = { - if (recordWriter != null) { + if (created) { recordWriter.close(Reporter.NULL) } } } - @DeveloperApi -private[sql] case class OrcRelation(override val paths: Array[String], +private[sql] case class OrcRelation( + override val paths: Array[String], parameters: Map[String, String], maybeSchema: Option[StructType] = None, maybePartitionSpec: Option[PartitionSpec] = None)( @transient val sqlContext: SQLContext) extends FSBasedRelation(paths, maybePartitionSpec) with Logging { - self: Product => - @transient val conf = sqlContext.sparkContext.hadoopConfiguration - - - override def dataSchema: StructType = - maybeSchema.getOrElse(OrcFileOperator.readSchema(paths(0), Some(conf))) + override 
val dataSchema: StructType = + maybeSchema.getOrElse(OrcFileOperator.readSchema(paths(0), + Some(sqlContext.sparkContext.hadoopConfiguration))) override def outputWriterClass: Class[_ <: OutputWriter] = classOf[OrcOutputWriter] - /** Attributes */ - var output: Seq[Attribute] = schema.toAttributes override def needConversion: Boolean = false - // Equals must also take into account the output attributes so that we can distinguish between - // different instances of the same relation, override def equals(other: Any): Boolean = other match { case that: OrcRelation => paths.toSet == that.paths.toSet && @@ -162,6 +146,7 @@ private[sql] case class OrcRelation(override val paths: Array[String], schema, maybePartitionSpec) } + override def buildScan(requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String]): RDD[Row] = { @@ -169,8 +154,3 @@ private[sql] case class OrcRelation(override val paths: Array[String], OrcTableScan(output, this, filters, inputPaths).execute() } } - -private[sql] object OrcRelation extends Logging { - // Default partition name to use when the partition column value is null or empty string. - val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala index 94c78a14524b5..2163b0ce70e99 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala @@ -20,50 +20,52 @@ package org.apache.spark.sql.hive.orc import java.util._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc._ import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.spark.rdd.RDD + +import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.hive.HiveShim import org.apache.spark.sql.sources.Filter import org.apache.spark.{Logging, SerializableWritable} + +/* Implicit conversions */ import scala.collection.JavaConversions._ -case class OrcTableScan(attributes: Seq[Attribute], +private[orc] case class OrcTableScan(attributes: Seq[Attribute], @transient relation: OrcRelation, filters: Array[Filter], inputPaths: Array[String]) extends Logging { - @transient val sqlContext = relation.sqlContext - val path = relation.paths(0) + @transient private val sqlContext = relation.sqlContext - def addColumnIds(output: Seq[Attribute], - relation: OrcRelation, conf: Configuration) { - val ids = - output.map(a => - relation.dataSchema.toAttributes.indexWhere(_.name == a.name): Integer) - .filter(_ >= 0) - val names = attributes.map(_.name) - val sorted = ids.zip(names).sorted - HiveShim.appendReadColumns(conf, sorted.map(_._1), sorted.map(_._2)) + private def addColumnIds( + output: Seq[Attribute], + relation: OrcRelation, + conf: Configuration): Unit = { + val ids = output.map(a => relation.dataSchema.fieldIndex(a.name): Integer) + val (sortedIds, sortedNames) = ids.zip(attributes.map(_.name)).sorted.unzip + HiveShim.appendReadColumns(conf, sortedIds, sortedNames) } - def buildFilter(job: Job, filters: Array[Filter]): Unit = { + private def buildFilter(job: Job, 
filters: Array[Filter]): Unit = { if (ORC_FILTER_PUSHDOWN_ENABLED) { val conf: Configuration = job.getConfiguration - val recordFilter = OrcFilters.createFilter(filters) - if (recordFilter.isDefined) { - conf.set(SARG_PUSHDOWN, toKryo(recordFilter.get)) - conf.setBoolean(INDEX_FILTER, true) + OrcFilters.createFilter(filters).foreach { f => + conf.set(SARG_PUSHDOWN, toKryo(f)) + conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) } } } // Transform all given raw `Writable`s into `Row`s. - def fillObject(conf: Configuration, + private def fillObject( + path: String, + conf: Configuration, iterator: Iterator[org.apache.hadoop.io.Writable], nonPartitionKeyAttrs: Seq[(Attribute, Int)], mutableRow: MutableRow): Iterator[Row] = { @@ -77,7 +79,6 @@ case class OrcTableScan(attributes: Seq[Attribute], // Map each tuple to a row object iterator.map { value => val raw = deserializer.deserialize(value) - logDebug("Raw data: " + raw) var i = 0 while (i < fieldRefs.length) { val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) @@ -105,11 +106,13 @@ case class OrcTableScan(attributes: Seq[Attribute], Class[_ <: org.apache.hadoop.mapred.InputFormat[NullWritable, Writable]]] val rdd = sc.hadoopRDD(conf.asInstanceOf[JobConf], - inputClass, classOf[NullWritable], classOf[Writable]).map(_._2) - val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + inputClass, classOf[NullWritable], classOf[Writable]) + .asInstanceOf[HadoopRDD[NullWritable, Writable]] val wrappedConf = new SerializableWritable(conf) - val rowRdd: RDD[Row] = rdd.mapPartitions { iter => - fillObject(wrappedConf.value, iter, attributes.zipWithIndex, mutableRow) + val rowRdd: RDD[Row] = rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iter) => + val pathStr = split.getPath.toString + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + fillObject(pathStr, wrappedConf.value, iter.map(_._2), attributes.zipWithIndex, mutableRow) } rowRdd } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala index 93af5af196dab..b219fbb44ca0d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala @@ -21,19 +21,21 @@ import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.Kryo import org.apache.commons.codec.binary.Base64 import org.apache.spark.sql.{SaveMode, DataFrame} -import scala.reflect.runtime.universe.{TypeTag, typeTag} package object orc { implicit class OrcContext(sqlContext: HiveContext) { import sqlContext._ @scala.annotation.varargs - def orcFile(paths: String*): DataFrame = { - if (paths.isEmpty) { - emptyDataFrame - } else { - val orcRelation = OrcRelation(paths.toArray, Map.empty)(sqlContext) - sqlContext.baseRelationToDataFrame(orcRelation) + def orcFile(path: String, paths: String*): DataFrame = { + val pathArray: Array[String] = { + if (paths.isEmpty) { + Array(path) + } else { + paths.toArray ++ Array(path) + } } + val orcRelation = OrcRelation(pathArray, Map.empty)(sqlContext) + sqlContext.baseRelationToDataFrame(orcRelation) } } @@ -49,8 +51,7 @@ package object orc { // Flags for orc copression, predicates pushdown, etc. 
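Note on buildFilter above: it hands the converted SearchArgument to the ORC input format purely through the job configuration. The argument is Kryo-serialized and Base64-encoded by the toKryo helper in this package object, stored under sarg.pushdown, and hive.optimize.index.filter switches row-group index filtering on. A sketch of that plumbing with the encoding written out; this is assumed equivalent, not a drop-in replacement for the helpers defined just below:

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.Output
import org.apache.commons.codec.binary.Base64
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.ql.io.sarg.SearchArgument

// Kryo-serialize the SearchArgument and Base64-encode it for transport in the conf.
def serializeSarg(sarg: SearchArgument): String = {
  val out = new Output(4 * 1024, 10 * 1024 * 1024)
  new Kryo().writeObject(out, sarg)
  out.close()
  Base64.encodeBase64String(out.toBytes)
}

// What buildFilter effectively does once OrcFilters.createFilter succeeds.
def pushDownFilter(conf: Configuration, sarg: SearchArgument): Unit = {
  conf.set("sarg.pushdown", serializeSarg(sarg))       // SARG_PUSHDOWN
  conf.setBoolean("hive.optimize.index.filter", true)  // ConfVars.HIVEOPTINDEXFILTER
}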
val orcDefaultCompressVar = "hive.exec.orc.default.compress" var ORC_FILTER_PUSHDOWN_ENABLED = true - val SARG_PUSHDOWN = "sarg.pushdown"; - val INDEX_FILTER = "hive.optimize.index.filter" + val SARG_PUSHDOWN = "sarg.pushdown" def toKryo(input: Any): String = { val out = new Output(4 * 1024, 10 * 1024 * 1024); diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 7ebf3c6eced26..31a829a81124d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.orc import java.io.File +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.hive.test.TestHive @@ -37,7 +38,7 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { - val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" + val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal def withTempDir(f: File => Unit): Unit = { val dir = Utils.createTempDir().getCanonicalFile @@ -187,7 +188,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before "org.apache.spark.sql.hive.orc.DefaultSource", Map( "path" -> base.getCanonicalPath, - OrcRelation.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + ConfVars.DEFAULTPARTITIONNAME.varname -> defaultPartitionName)) orcRelation.registerTempTable("t") @@ -232,7 +233,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before "org.apache.spark.sql.hive.orc.DefaultSource", Map( "path" -> base.getCanonicalPath, - OrcRelation.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + ConfVars.DEFAULTPARTITIONNAME.varname -> defaultPartitionName)) orcRelation.registerTempTable("t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 3c596d0654324..475af3d4c94e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -153,13 +153,40 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { data.toDF().saveAsOrcFile(tempDir) val f = TestHive.orcFile(tempDir) f.registerTempTable("tmp") + + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // expr = leaf-0 var rdd = sql("SELECT name FROM tmp where age <= 5") assert(rdd.count() == 5) + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // expr = (not leaf-0) rdd = sql("SELECT name, contacts FROM tmp where age > 5") assert(rdd.count() == 5) - val contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) + var contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) assert(contacts.count() == 10) + + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // leaf-1 = (LESS_THAN age 8) + // expr = (and (not leaf-0) leaf-1) + rdd = sql("SELECT name, contacts FROM tmp where age > 5 and age < 8") + assert(rdd.count() == 2) + contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) + assert(contacts.count() == 4) + + // ppd: + // leaf-0 = (LESS_THAN age 2) + // leaf-1 = 
(LESS_THAN_EQUALS age 8) + // expr = (or leaf-0 (not leaf-1)) + rdd = sql("SELECT name, contacts FROM tmp where age < 2 or age > 8") + assert(rdd.count() == 3) + contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) + assert(contacts.count() == 6) + + Utils.deleteRecursively(new File(tempDir)) } From 2650a422643c8749f7c75e38152df649577337ce Mon Sep 17 00:00:00 2001 From: Zhan Zhang Date: Thu, 14 May 2015 13:22:07 -0700 Subject: [PATCH 07/12] resolve review comments --- .../main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index c68a58647cad7..44ac728b09aa3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -35,6 +35,8 @@ import org.apache.spark.sql.hive.HiveMetastoreTypes import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.sources._ + +/* Implicit conversions */ import scala.collection.JavaConversions._ From d7344968aacaca0c418653c0ed3bd4daa5f78409 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 16 May 2015 00:17:51 +0800 Subject: [PATCH 08/12] Polishes the ORC data source --- .../scala/org/apache/spark/sql/SQLConf.scala | 7 +- .../spark/sql/parquet/ParquetTest.scala | 61 +---- .../apache/spark/sql/test/SQLTestUtils.scala | 81 ++++++ .../sql/hive/orc/HadoopTypeConverter.scala | 3 +- .../spark/sql/hive/orc/OrcFileOperator.scala | 14 +- .../spark/sql/hive/orc/OrcFilters.scala | 146 ++++++++--- .../spark/sql/hive/orc/OrcRelation.scala | 248 +++++++++++++----- .../sql/hive/orc/OrcTableOperations.scala | 119 --------- .../apache/spark/sql/hive/orc/package.scala | 24 +- .../spark/sql/hive/orc/NewOrcQuerySuite.scala | 177 +++++++++++++ ...e.scala => OrcHadoopFsRelationSuite.scala} | 5 +- .../hive/orc/OrcPartitionDiscoverySuite.scala | 1 + .../spark/sql/hive/orc/OrcQuerySuite.scala | 27 +- .../sql/sources/hadoopFsRelationSuites.scala | 6 +- 14 files changed, 591 insertions(+), 328 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala rename sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/{OrcRelationSuite.scala => OrcHadoopFsRelationSuite.scala} (94%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f07bb196c11ec..6da910e332e9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -43,6 +43,8 @@ private[spark] object SQLConf { val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" + val ORC_FILTER_PUSHDOWN_ENABLED = "spark.sql.orc.filterPushdown" + val HIVE_VERIFY_PARTITIONPATH = "spark.sql.hive.verifyPartitionPath" val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" @@ -143,6 +145,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def parquetUseDataSourceApi = getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + private[spark] def 
orcFilterPushDown = + getConf(ORC_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + /** When true uses verifyPartitionPath to prune the path which is not exists. */ private[spark] def verifyPartitionPath = getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean @@ -254,7 +259,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean - + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 7a73b6f1ac601..516ba373f41d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -21,10 +21,9 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import scala.util.Try -import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} -import org.apache.spark.util.Utils +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{DataFrame, SaveMode} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -33,54 +32,9 @@ import org.apache.spark.util.Utils * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest { - val sqlContext: SQLContext - +private[sql] trait ParquetTest extends SQLTestUtils { import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder} - import sqlContext.{conf, sparkContext} - - protected def configuration = sparkContext.hadoopConfiguration - - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL - * configurations. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) - (keys, values).zipped.foreach(conf.setConf) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => conf.setConf(key, value) - case (key, None) => conf.unsetConf(key) - } - } - } - - /** - * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If - * a file/directory is created there by `f`, it will be delete after `f` returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - /** - * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` - * returns. 
- * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) - } + import sqlContext.sparkContext /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` @@ -105,13 +59,6 @@ private[sql] trait ParquetTest { withParquetFile(data)(path => f(sqlContext.read.parquet(path))) } - /** - * Drops temporary table `tableName` after calling `f`. - */ - protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally sqlContext.dropTempTable(tableName) - } - /** * Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a * temporary table named `tableName`, then call `f`. The temporary table together with the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala new file mode 100644 index 0000000000000..75d290625ec38 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.io.File + +import scala.util.Try + +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils + +trait SQLTestUtils { + val sqlContext: SQLContext + + import sqlContext.{conf, sparkContext} + + protected def configuration = sparkContext.hadoopConfiguration + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * configurations. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) + (keys, values).zipped.foreach(conf.setConf) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConf(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } + + /** + * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If + * a file/directory is created there by `f`, it will be delete after `f` returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. 
+ * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + /** + * Drops temporary table `tableName` after calling `f`. + */ + protected def withTempTable(tableName: String)(f: => Unit): Unit = { + try f finally sqlContext.dropTempTable(tableName) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala index 713c076aee457..b5b5e56079cc3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.hive.orc - import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} +import org.apache.spark.sql.hive.HiveInspectors /** * We can consolidate TableReader.unwrappers and HiveInspectors.wrapperFor to use diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 4dd2d8951b728..1e51173a19882 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -28,28 +28,25 @@ import org.apache.spark.sql.hive.HiveMetastoreTypes import org.apache.spark.sql.types.StructType private[orc] object OrcFileOperator extends Logging{ - def getFileReader(pathStr: String, config: Option[Configuration] = None ): Reader = { val conf = config.getOrElse(new Configuration) val fspath = new Path(pathStr) val fs = fspath.getFileSystem(conf) val orcFiles = listOrcFiles(pathStr, conf) - OrcFile.createReader(fs, orcFiles(0)) + + // TODO Need to consider all files when schema evolution is taken into account. 
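Together with the new spark.sql.orc.filterPushdown flag, the SQLTestUtils helpers extracted above are what the ORC suites later in this series build on. A minimal sketch of how they compose (column names, data, and the temporary path are illustrative only; assumes a HiveContext-backed sqlContext plus the implicits from the hive.orc package object):

    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.sql.hive.orc._            // orcFile / saveAsOrcFile implicits
    import sqlContext.implicits._

    val hiveContext = sqlContext.asInstanceOf[HiveContext]

    withSQLConf("spark.sql.orc.filterPushdown" -> "true") {
      withTempPath { dir =>
        val path = dir.getCanonicalPath
        sqlContext.sparkContext
          .parallelize(1 to 10)
          .map(i => (i, s"val_$i"))
          .toDF("a", "b")
          .saveAsOrcFile(path)
        withTempTable("t") {
          hiveContext.orcFile(path).registerTempTable("t")
          // `a > 5` should reach the ORC reader as a pushed-down SearchArgument
          assert(sqlContext.sql("SELECT b FROM t WHERE a > 5").count() === 5)
        }
      }
    }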
+ OrcFile.createReader(fs, orcFiles.head) } def readSchema(path: String, conf: Option[Configuration]): StructType = { val reader = getFileReader(path, conf) - val readerInspector: StructObjectInspector = reader.getObjectInspector - .asInstanceOf[StructObjectInspector] + val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] val schema = readerInspector.getTypeName HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] } def getObjectInspector(path: String, conf: Option[Configuration]): StructObjectInspector = { - val reader = getFileReader(path, conf) - val readerInspector: StructObjectInspector = reader.getObjectInspector - .asInstanceOf[StructObjectInspector] - readerInspector + getFileReader(path, conf).getObjectInspector.asInstanceOf[StructObjectInspector] } def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { @@ -66,6 +63,7 @@ private[orc] object OrcFileOperator extends Logging{ throw new IllegalArgumentException( s"orcFileOperator: path $path does not have valid orc files matching the pattern") } + paths } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index eda1cffe49810..9bee4f59b5854 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.hive.orc +import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.ql.io.sarg.SearchArgument import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder +import org.apache.hadoop.hive.serde2.io.DateWritable + import org.apache.spark.Logging import org.apache.spark.sql.sources._ @@ -29,48 +32,113 @@ import org.apache.spark.sql.sources._ */ private[orc] object OrcFilters extends Logging { def createFilter(expr: Array[Filter]): Option[SearchArgument] = { - if (expr.nonEmpty) { - expr.foldLeft(Some(SearchArgument.FACTORY.newBuilder().startAnd()): Option[Builder]) { - (maybeBuilder, e) => createFilter(e, maybeBuilder) - }.map(_.end().build()) - } else { - None + expr.reduceOption(And).flatMap { conjunction => + val builder = SearchArgument.FACTORY.newBuilder() + buildSearchArgument(conjunction, builder).map(_.build()) } } - private def createFilter(expression: Filter, maybeBuilder: Option[Builder]): Option[Builder] = { - maybeBuilder.flatMap { builder => - expression match { - case p@And(left, right) => - for { - lhs <- createFilter(left, Some(builder.startAnd())) - rhs <- createFilter(right, Some(lhs)) - } yield rhs.end() - case p@Or(left, right) => - for { - lhs <- createFilter(left, Some(builder.startOr())) - rhs <- createFilter(right, Some(lhs)) - } yield rhs.end() - case p@Not(child) => - createFilter(child, Some(builder.startNot())).map(_.end()) - case p@EqualTo(attribute, value) => - Some(builder.equals(attribute, value)) - case p@LessThan(attribute, value) => - Some(builder.lessThan(attribute, value)) - case p@LessThanOrEqual(attribute, value) => - Some(builder.lessThanEquals(attribute, value)) - case p@GreaterThan(attribute, value) => - Some(builder.startNot().lessThanEquals(attribute, value).end()) - case p@GreaterThanOrEqual(attribute, value) => - Some(builder.startNot().lessThan(attribute, value).end()) - case p@IsNull(attribute) => - Some(builder.isNull(attribute)) - case p@IsNotNull(attribute) => - Some(builder.startNot().isNull(attribute).end()) - case 
p@In(attribute, values) => - Some(builder.in(attribute, values)) - case _ => None - } + private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { + def newBuilder = SearchArgument.FACTORY.newBuilder() + + def isSearchableLiteral(value: Any) = value match { + // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. + case _: String | _: Long | _: Double | _: DateWritable | _: HiveDecimal | _: HiveChar | + _: HiveVarchar | _: Byte | _: Short | _: Integer | _: Float => true + case _ => false + } + + // lian: I probably missed something here, and had to end up with a pretty weird double-checking + // pattern when converting `And`/`Or`/`Not` filters. + // + // The annoying part is that, `SearchArgument` builder methods like `startAnd()` `startOr()`, + // and `startNot()` mutate internal state of the builder instance. This forces us to translate + // all convertible filters with a single builder instance. However, before actually converting a + // filter, we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible + // filter is found, we may already end up with a builder whose internal state is inconsistent. + // + // For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and + // then try to convert its children. Say we convert `left` child successfully, but find that + // `right` child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is + // inconsistent now. + // + // The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + // children with brand new builders, and only do the actual conversion with the right builder + // instance when the children are proven to be convertible. + // + // P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. + // Usage of builder methods mentioned above can only be found in test code, where all tested + // filters are known to be convertible. + + expression match { + case And(left, right) => + val tryLeft = buildSearchArgument(left, newBuilder) + val tryRight = buildSearchArgument(right, newBuilder) + + val conjunction = for { + _ <- tryLeft + _ <- tryRight + lhs <- buildSearchArgument(left, builder.startAnd()) + rhs <- buildSearchArgument(right, lhs) + } yield rhs.end() + + // For filter `left AND right`, we can still push down `left` even if `right` is not + // convertible, and vice versa. 
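A concrete, hedged illustration of the fallback below, assuming StringStartsWith as the inconvertible side (no leaf is generated for it in this translation):

    import org.apache.spark.sql.sources._

    // age < 8 AND name LIKE 'Sp%'  -->  only `age < 8` can be pushed down
    val pushed = OrcFilters.createFilter(Array(
      And(LessThan("age", 8), StringStartsWith("name", "Sp"))))
    // `pushed` should be Some(sarg), with sarg covering just the (LESS_THAN age 8) leaf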
+ conjunction + .orElse(tryLeft.flatMap(_ => buildSearchArgument(left, builder))) + .orElse(tryRight.flatMap(_ => buildSearchArgument(right, builder))) + + case And(left, right) => + for { + _ <- buildSearchArgument(left, newBuilder) + _ <- buildSearchArgument(right, newBuilder) + lhs <- buildSearchArgument(left, builder.startOr()) + rhs <- buildSearchArgument(right, lhs) + } yield rhs.end() + + case Not(child) => + for { + _ <- buildSearchArgument(child, newBuilder) + negate <- buildSearchArgument(child, builder.startNot()) + } yield negate.end() + + case EqualTo(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.equals(attribute, _)) + + case LessThan(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.lessThan(attribute, _)) + + case LessThanOrEqual(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.lessThanEquals(attribute, _)) + + case GreaterThan(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.startNot().lessThanEquals(attribute, _).end()) + + case GreaterThanOrEqual(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.startNot().lessThan(attribute, _).end()) + + case IsNull(attribute) => + Some(builder.isNull(attribute)) + + case IsNotNull(attribute) => + Some(builder.startNot().isNull(attribute).end()) + + case In(attribute, values) => + Option(values) + .filter(_.forall(isSearchableLiteral)) + .map(builder.in(attribute, _)) + + case _ => None } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 44ac728b09aa3..3e3c8a9e619d5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -17,99 +17,114 @@ package org.apache.spark.sql.hive.orc -import java.util.Objects +import java.util.{Objects, Properties} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.ql.io.orc.{OrcSerde, OrcOutputFormat} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StructObjectInspector} -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfoUtils, TypeInfo} -import org.apache.hadoop.io.{Writable, NullWritable} -import org.apache.hadoop.mapred.{RecordWriter, Reporter, JobConf} -import org.apache.hadoop.mapreduce.{TaskID, TaskAttemptContext} - -import org.apache.spark.Logging +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit} +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, RecordWriter, Reporter} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.hive.HiveMetastoreTypes +import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveShim} +import org.apache.spark.sql.sources.{Filter, _} import 
org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.sources._ +import org.apache.spark.{Logging, SerializableWritable} /* Implicit conversions */ import scala.collection.JavaConversions._ - -private[sql] class DefaultSource extends FSBasedRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider { def createRelation( sqlContext: SQLContext, paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], - parameters: Map[String, String]): FSBasedRelation ={ + parameters: Map[String, String]): HadoopFsRelation = { val partitionSpec = partitionColumns.map(PartitionSpec(_, Seq.empty[Partition])) - OrcRelation(paths, parameters, - schema, partitionSpec)(sqlContext) + OrcRelation(paths, parameters, schema, partitionSpec)(sqlContext) } } +private[orc] class OrcOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter with SparkHadoopMapRedUtil { -private[sql] class OrcOutputWriter extends OutputWriter with SparkHadoopMapRedUtil { - - var taskAttemptContext: TaskAttemptContext = _ - var serializer: OrcSerde = _ - var wrappers: Array[Any => Any] = _ - var created = false - var path: String = _ - var dataSchema: StructType = _ - var fieldOIs: Array[ObjectInspector] = _ - var structOI: StructObjectInspector = _ - var outputData: Array[Any] = _ - lazy val recordWriter: RecordWriter[NullWritable, Writable] = { - created = true - val conf = taskAttemptContext.getConfiguration - val taskId: TaskID = taskAttemptContext.getTaskAttemptID.getTaskID - val partition: Int = taskId.getId - val filename = f"part-r-$partition%05d-${System.currentTimeMillis}%015d.orc" - val file = new Path(path, filename) - val fs = file.getFileSystem(conf) - val outputFormat = new OrcOutputFormat() - outputFormat.getRecordWriter(fs, - conf.asInstanceOf[JobConf], - file.toUri.getPath, Reporter.NULL) - .asInstanceOf[org.apache.hadoop.mapred.RecordWriter[NullWritable, Writable]] + private val serializer = { + val table = new Properties() + table.setProperty("columns", dataSchema.fieldNames.mkString(",")) + table.setProperty("columns.types", dataSchema.map { f => + HiveMetastoreTypes.toMetastoreType(f.dataType) + }.mkString(":")) + + val serde = new OrcSerde + serde.initialize(context.getConfiguration, table) + serde } - override def init(path: String, - dataSchema: StructType, - context: TaskAttemptContext): Unit = { - this.path = path - taskAttemptContext = context - val orcSchema = HiveMetastoreTypes.toMetastoreType(dataSchema) - serializer = new OrcSerde - val typeInfo: TypeInfo = - TypeInfoUtils.getTypeInfoFromTypeString(orcSchema) - structOI = TypeInfoUtils + // Object inspector converted from the schema of the relation to be written. + private val structOI = { + val typeInfo = + TypeInfoUtils.getTypeInfoFromTypeString( + HiveMetastoreTypes.toMetastoreType(dataSchema)) + + TypeInfoUtils .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) .asInstanceOf[StructObjectInspector] - fieldOIs = structOI - .getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray - outputData = new Array[Any](fieldOIs.length) - wrappers = fieldOIs.map(HadoopTypeConverter.wrappers) + } + + // Used to hold temporary `Writable` fields of the next row to be written. + private val reusableOutputBuffer = new Array[Any](dataSchema.length) + + // Used to convert Catalyst values into Hadoop `Writable`s. 
+ private val wrappers = structOI.getAllStructFieldRefs.map { ref => + HadoopTypeConverter.wrappers(ref.getFieldObjectInspector) + }.toArray + + // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this + // flag to decide whether `OrcRecordWriter.close()` needs to be called. + private var recordWriterInstantiated = false + + private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { + recordWriterInstantiated = true + + val conf = context.getConfiguration + val partition = context.getTaskAttemptID.getTaskID.getId + val filename = f"part-r-$partition%05d-${System.currentTimeMillis}%015d.orc" + + new OrcOutputFormat().getRecordWriter( + new Path(path, filename).getFileSystem(conf), + conf.asInstanceOf[JobConf], + new Path(path, filename).toUri.getPath, + Reporter.NULL + ).asInstanceOf[RecordWriter[NullWritable, Writable]] } override def write(row: Row): Unit = { var i = 0 while (i < row.length) { - outputData(i) = wrappers(i)(row(i)) + reusableOutputBuffer(i) = wrappers(i)(row(i)) i += 1 } - val writable = serializer.serialize(outputData, structOI) - recordWriter.write(NullWritable.get(), writable) + + recordWriter.write( + NullWritable.get(), + serializer.serialize(reusableOutputBuffer, structOI)) } override def close(): Unit = { - if (created) { + if (recordWriterInstantiated) { recordWriter.close(Reporter.NULL) } } @@ -122,13 +137,16 @@ private[sql] case class OrcRelation( maybeSchema: Option[StructType] = None, maybePartitionSpec: Option[PartitionSpec] = None)( @transient val sqlContext: SQLContext) - extends FSBasedRelation(paths, maybePartitionSpec) + extends HadoopFsRelation(maybePartitionSpec) with Logging { - override val dataSchema: StructType = - maybeSchema.getOrElse(OrcFileOperator.readSchema(paths(0), - Some(sqlContext.sparkContext.hadoopConfiguration))) - override def outputWriterClass: Class[_ <: OutputWriter] = classOf[OrcOutputWriter] + override val dataSchema: StructType = maybeSchema.getOrElse { + OrcFileOperator.readSchema( + paths.head, Some(sqlContext.sparkContext.hadoopConfiguration)) + } + + override def userDefinedPartitionColumns: Option[StructType] = + maybePartitionSpec.map(_.partitionColumns) override def needConversion: Boolean = false @@ -155,4 +173,106 @@ private[sql] case class OrcRelation( val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes OrcTableScan(output, this, filters, inputPaths).execute() } + + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OrcOutputWriter(path, dataSchema, context) + } + } + } +} + +private[orc] case class OrcTableScan( + attributes: Seq[Attribute], + @transient relation: OrcRelation, + filters: Array[Filter], + inputPaths: Array[String]) extends Logging { + @transient private val sqlContext = relation.sqlContext + + private def addColumnIds( + output: Seq[Attribute], + relation: OrcRelation, + conf: Configuration): Unit = { + val ids = output.map(a => relation.dataSchema.fieldIndex(a.name): Integer) + val (sortedIds, sortedNames) = ids.zip(attributes.map(_.name)).sorted.unzip + HiveShim.appendReadColumns(conf, sortedIds, sortedNames) + } + + // Transform all given raw `Writable`s into `Row`s. 
+ private def fillObject( + path: String, + conf: Configuration, + iterator: Iterator[Writable], + nonPartitionKeyAttrs: Seq[(Attribute, Int)], + mutableRow: MutableRow): Iterator[Row] = { + val deserializer = new OrcSerde + val soi = OrcFileOperator.getObjectInspector(path, Some(conf)) + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { + case (attr, ordinal) => + soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal + }.unzip + val unwrappers = HadoopTypeConverter.unwrappers(fieldRefs) + // Map each tuple to a row object + iterator.map { value => + val raw = deserializer.deserialize(value) + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 + } + mutableRow: Row + } + } + + def execute(): RDD[Row] = { + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration + + // Tries to push down filters if ORC filter push-down is enabled + if (sqlContext.conf.orcFilterPushDown) { + OrcFilters.createFilter(filters).foreach { f => + conf.set(SARG_PUSHDOWN, f.toKryo) + conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) + } + } + + // Sets requested columns + addColumnIds(attributes, relation, conf) + + if (inputPaths.nonEmpty) { + FileInputFormat.setInputPaths(job, inputPaths.map(new Path(_)): _*) + } + + val inputFormatClass = + classOf[OrcInputFormat] + .asInstanceOf[Class[_ <: MapRedInputFormat[NullWritable, Writable]]] + + val rdd = sqlContext.sparkContext.hadoopRDD( + conf.asInstanceOf[JobConf], + inputFormatClass, + classOf[NullWritable], + classOf[Writable] + ).asInstanceOf[HadoopRDD[NullWritable, Writable]] + + val wrappedConf = new SerializableWritable(conf) + + rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) => + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + fillObject( + split.getPath.toString, + wrappedConf.value, + iterator.map(_._2), + attributes.zipWithIndex, + mutableRow) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala deleted file mode 100644 index 2163b0ce70e99..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
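OrcTableScan.execute() above is what an ordinary read now drives end to end; a small sketch of the user-visible effect (paths are illustrative, and the pushed-down predicate only prunes ORC stripes/row groups, so Spark still re-evaluates the filter on the rows that survive):

    val people = hiveContext.orcFile("/tmp/people.orc")
    people.registerTempTable("people")

    // With spark.sql.orc.filterPushdown=true, `id > 5` is serialized into the
    // `sarg.pushdown` property and hive.optimize.index.filter is switched on
    // before the underlying HadoopRDD is created.
    hiveContext.sql("SELECT name FROM people WHERE id > 5").collect()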
- */ - -package org.apache.spark.sql.hive.orc - -import java.util._ -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.io.orc._ -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapreduce.Job -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat - -import org.apache.spark.rdd.{HadoopRDD, RDD} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.hive.HiveShim -import org.apache.spark.sql.sources.Filter -import org.apache.spark.{Logging, SerializableWritable} - -/* Implicit conversions */ -import scala.collection.JavaConversions._ - -private[orc] case class OrcTableScan(attributes: Seq[Attribute], - @transient relation: OrcRelation, - filters: Array[Filter], - inputPaths: Array[String]) extends Logging { - @transient private val sqlContext = relation.sqlContext - - private def addColumnIds( - output: Seq[Attribute], - relation: OrcRelation, - conf: Configuration): Unit = { - val ids = output.map(a => relation.dataSchema.fieldIndex(a.name): Integer) - val (sortedIds, sortedNames) = ids.zip(attributes.map(_.name)).sorted.unzip - HiveShim.appendReadColumns(conf, sortedIds, sortedNames) - } - - private def buildFilter(job: Job, filters: Array[Filter]): Unit = { - if (ORC_FILTER_PUSHDOWN_ENABLED) { - val conf: Configuration = job.getConfiguration - OrcFilters.createFilter(filters).foreach { f => - conf.set(SARG_PUSHDOWN, toKryo(f)) - conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) - } - } - } - - // Transform all given raw `Writable`s into `Row`s. - private def fillObject( - path: String, - conf: Configuration, - iterator: Iterator[org.apache.hadoop.io.Writable], - nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow): Iterator[Row] = { - val deserializer = new OrcSerde - val soi = OrcFileOperator.getObjectInspector(path, Some(conf)) - val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { - case (attr, ordinal) => - soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal - }.unzip - val unwrappers = HadoopTypeConverter.unwrappers(fieldRefs) - // Map each tuple to a row object - iterator.map { value => - val raw = deserializer.deserialize(value) - var i = 0 - while (i < fieldRefs.length) { - val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) - if (fieldValue == null) { - mutableRow.setNullAt(fieldOrdinals(i)) - } else { - unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) - } - i += 1 - } - mutableRow: Row - } - } - - def execute(): RDD[Row] = { - val sc = sqlContext.sparkContext - val job = new Job(sc.hadoopConfiguration) - val conf: Configuration = job.getConfiguration - - buildFilter(job, filters) - addColumnIds(attributes, relation, conf) - FileInputFormat.setInputPaths(job, inputPaths.map(new Path(_)): _*) - - val inputClass = classOf[OrcInputFormat].asInstanceOf[ - Class[_ <: org.apache.hadoop.mapred.InputFormat[NullWritable, Writable]]] - - val rdd = sc.hadoopRDD(conf.asInstanceOf[JobConf], - inputClass, classOf[NullWritable], classOf[Writable]) - .asInstanceOf[HadoopRDD[NullWritable, Writable]] - val wrappedConf = new SerializableWritable(conf) - val rowRdd: RDD[Row] = rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iter) => - val pathStr = split.getPath.toString - val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) - fillObject(pathStr, wrappedConf.value, iter.map(_._2), 
attributes.zipWithIndex, mutableRow) - } - rowRdd - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala index b219fbb44ca0d..869c8a5b8f1db 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala @@ -17,14 +17,10 @@ package org.apache.spark.sql.hive -import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.Kryo -import org.apache.commons.codec.binary.Base64 -import org.apache.spark.sql.{SaveMode, DataFrame} +import org.apache.spark.sql.{DataFrame, SaveMode} package object orc { implicit class OrcContext(sqlContext: HiveContext) { - import sqlContext._ @scala.annotation.varargs def orcFile(path: String, paths: String*): DataFrame = { val pathArray: Array[String] = { @@ -34,29 +30,21 @@ package object orc { paths.toArray ++ Array(path) } } + val orcRelation = OrcRelation(pathArray, Map.empty)(sqlContext) sqlContext.baseRelationToDataFrame(orcRelation) } } - implicit class OrcSchemaRDD(dataFrame: DataFrame) { + implicit class OrcDataFrame(dataFrame: DataFrame) { def saveAsOrcFile(path: String, mode: SaveMode = SaveMode.Overwrite): Unit = { - dataFrame.save( - path, - source = classOf[DefaultSource].getCanonicalName, - mode) + dataFrame.save(path, source = classOf[DefaultSource].getCanonicalName, mode) } } // Flags for orc copression, predicates pushdown, etc. val orcDefaultCompressVar = "hive.exec.orc.default.compress" - var ORC_FILTER_PUSHDOWN_ENABLED = true - val SARG_PUSHDOWN = "sarg.pushdown" - def toKryo(input: Any): String = { - val out = new Output(4 * 1024, 10 * 1024 * 1024); - new Kryo().writeObject(out, input); - out.close(); - Base64.encodeBase64String(out.toBytes()); - } + // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. + val SARG_PUSHDOWN = "sarg.pushdown" } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala new file mode 100644 index 0000000000000..7e326de1335e0 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql._ + +private[sql] trait OrcTest extends SQLTestUtils { + protected def hiveContext = sqlContext.asInstanceOf[HiveContext] + + import sqlContext.sparkContext + import sqlContext.implicits._ + + /** + * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withOrcFile[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: String => Unit): Unit = { + withTempPath { file => + sparkContext.parallelize(data).toDF().saveAsOrcFile(file.getCanonicalPath) + f(file.getCanonicalPath) + } + } + + /** + * Writes `data` to a Orc file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Orc file will be deleted after `f` returns. + */ + protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: DataFrame => Unit): Unit = { + import org.apache.spark.sql.hive.orc.OrcContext + withOrcFile(data)(path => f(hiveContext.orcFile(path))) + } + + /** + * Writes `data` to a Orc file, reads it back as a [[DataFrame]] and registers it as a + * temporary table named `tableName`, then call `f`. The temporary table together with the + * Orc file will be dropped/deleted after `f` returns. + */ + protected def withOrcTable[T <: Product: ClassTag: TypeTag] + (data: Seq[T], tableName: String) + (f: => Unit): Unit = { + withOrcDataFrame(data) { df => + hiveContext.registerDataFrameAsTable(df, tableName) + withTempTable(tableName)(f) + } + } + + protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.orc", SaveMode.Overwrite) + } + + protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.save(path.getCanonicalPath, "org.apache.spark.sql.orc", SaveMode.Overwrite) + } +} + +class NewOrcQuerySuite extends QueryTest with OrcTest { + override val sqlContext: SQLContext = TestHive + + import sqlContext._ + + test("simple select queries") { + withOrcTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer( + sql("SELECT `_1` FROM t where t.`_1` > 5"), + (6 until 10).map(Row.apply(_))) + + checkAnswer( + sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"), + (0 until 5).map(Row.apply(_))) + } + } + + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withOrcTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) + } + + test("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withOrcTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), data.map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) + } + + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } + + withOrcTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`") + val queryOutput = 
selfJoin.queryExecution.analyzed.output + + assertResult(4, "Field count mismatches")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size + } + + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + } + } + + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + withOrcTable(data, "t") { + checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) + } + } + + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + withOrcTable(data, "t") { + checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) + } + } + + test("columns only referenced by pushed down filters should remain") { + withOrcTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_))) + } + } + + test("SPARK-5309 strings stored using dictionary compression in orc") { + withOrcTable((0 until 1000).map(i => ("same", "run_" + i / 100, 1)), "t") { + checkAnswer( + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer( + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"), + List(Row("same", "run_5", 100))) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala similarity index 94% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 1d8c421b90678..90812b03fd2e6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.hive.orc import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.sources.{FSBasedRelationTest} +import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ - -class FSBasedOrcRelationSuite extends FSBasedRelationTest { +class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[DefaultSource].getCanonicalName import sqlContext._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 31a829a81124d..55d8b8c71d9ef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -37,6 +37,7 @@ case class OrcParData(intField: Int, stringField: String) // The data that also includes the partitioning key case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) +// TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 475af3d4c94e4..3d52c31eca9f7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -112,15 +112,16 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { test("Read/Write All Types with non-primitive type") { val tempDir = getTempFilePath("orcTest").getCanonicalPath - val range = (0 to 255) - val data = sparkContext.parallelize(range) - .map(x => AllDataTypesWithNonPrimitiveType( - s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 until x), - (0 until x).map(Option(_).filter(_ % 3 == 0)), - (0 until x).map(i => i -> i.toLong).toMap, - (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), - Data((0 until x), Nested(x, s"$x")))) + val range = 0 to 255 + val data = sparkContext.parallelize(range).map { x => + AllDataTypesWithNonPrimitiveType( + s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, + 0 until x, + (0 until x).map(Option(_).filter(_ % 3 == 0)), + (0 until x).map(i => i -> i.toLong).toMap, + (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), + Data(0 until x, Nested(x, s"$x"))) + } data.toDF().saveAsOrcFile(tempDir) checkAnswer( @@ -204,11 +205,11 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { // We only support zlib in hive0.12.0 now test("Default Compression options for writing to an Orcfile") { // TODO: support other compress codec - var tempDir = getTempFilePath("orcTest").getCanonicalPath - val rdd = sparkContext.parallelize((1 to 100)) + val tempDir = getTempFilePath("orcTest").getCanonicalPath + val rdd = sparkContext.parallelize(1 to 100) .map(i => TestRDDEntry(i, s"val_$i")) rdd.toDF().saveAsOrcFile(tempDir) - var actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression + val actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression assert(actualCodec == CompressionKind.ZLIB) Utils.deleteRecursively(new File(tempDir)) } @@ -217,7 +218,7 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { ignore("Other Compression options for writing to an Orcfile - 0.13.1 and above") { TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "SNAPPY") var tempDir = getTempFilePath("orcTest").getCanonicalPath - val rdd = sparkContext.parallelize((1 to 100)) + val rdd = sparkContext.parallelize(1 to 100) .map(i => TestRDDEntry(i, s"val_$i")) rdd.toDF().saveAsOrcFile(tempDir) var actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index f44b3c521e647..082933e0390f3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -23,12 +23,10 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.parquet.ParquetTest +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -// TODO Don't extend ParquetTest -// This test suite extends ParquetTest for some convenient utility methods. 
These methods should be -// moved to some more general places, maybe QueryTest. -class HadoopFsRelationTest extends QueryTest with ParquetTest { +abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { override val sqlContext: SQLContext = TestHive import sqlContext._ From 128bd3b025dbdafe473a4f1539d7150f153564fa Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 16 May 2015 09:56:51 +0800 Subject: [PATCH 09/12] ORC filter bug fix --- .../main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 9bee4f59b5854..250e73a4dba92 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -88,7 +88,7 @@ private[orc] object OrcFilters extends Logging { .orElse(tryLeft.flatMap(_ => buildSearchArgument(left, builder))) .orElse(tryRight.flatMap(_ => buildSearchArgument(right, builder))) - case And(left, right) => + case Or(left, right) => for { _ <- buildSearchArgument(left, newBuilder) _ <- buildSearchArgument(right, newBuilder) From 21ada225de783b767f60f9721b336344d16e72d9 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 16 May 2015 11:20:25 +0800 Subject: [PATCH 10/12] Adds @since and @Experimental annotations --- .../apache/spark/sql/hive/orc/package.scala | 52 ++++++++++++++----- .../spark/sql/hive/orc/OrcQuerySuite.scala | 8 +-- 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala index 869c8a5b8f1db..ad0f65442b914 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala @@ -17,34 +17,58 @@ package org.apache.spark.sql.hive +import org.apache.spark.annotation.Experimental import org.apache.spark.sql.{DataFrame, SaveMode} package object orc { + /** + * ::Experimental:: + * + * Extra ORC file loading functionality on [[HiveContext]] through implicit conversion. + * + * @since 1.4.0 + */ + @Experimental implicit class OrcContext(sqlContext: HiveContext) { + /** + * ::Experimental:: + * + * Loads specified Parquet files, returning the result as a [[DataFrame]]. + * + * @since 1.4.0 + */ + @Experimental @scala.annotation.varargs - def orcFile(path: String, paths: String*): DataFrame = { - val pathArray: Array[String] = { - if (paths.isEmpty) { - Array(path) - } else { - paths.toArray ++ Array(path) - } - } - - val orcRelation = OrcRelation(pathArray, Map.empty)(sqlContext) + def orcFile(paths: String*): DataFrame = { + val orcRelation = OrcRelation(paths.toArray, Map.empty)(sqlContext) sqlContext.baseRelationToDataFrame(orcRelation) } } + /** + * ::Experimental:: + * + * Extra ORC file writing functionality on [[DataFrame]] through implicit conversion + * + * @since 1.4.0 + */ + @Experimental implicit class OrcDataFrame(dataFrame: DataFrame) { + /** + * ::Experimental:: + * + * Saves the contents of this [[DataFrame]] as an ORC file, preserving the schema. Files that + * are written out using this method can be read back in as a [[DataFrame]] using + * [[OrcContext.orcFile()]]. 
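For instance, given some DataFrame df (paths illustrative; these implicits are superseded by the DataFrame reader/writer API in the next commit):

    import org.apache.spark.sql.SaveMode
    import org.apache.spark.sql.hive.orc._

    df.saveAsOrcFile("/data/orc/2015-05-15", SaveMode.Append)

    // the varargs orcFile introduced above accepts one or more ORC locations
    val combined = hiveContext.orcFile("/data/orc/2015-05-15", "/data/orc/2015-05-16")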
+ * + * @since 1.4.0 + */ + @Experimental def saveAsOrcFile(path: String, mode: SaveMode = SaveMode.Overwrite): Unit = { dataFrame.save(path, source = classOf[DefaultSource].getCanonicalName, mode) } } - // Flags for orc copression, predicates pushdown, etc. - val orcDefaultCompressVar = "hive.exec.orc.default.compress" - // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. - val SARG_PUSHDOWN = "sarg.pushdown" + private[orc] val SARG_PUSHDOWN = "sarg.pushdown" } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 3d52c31eca9f7..abc4c92d91da8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.orc import java.io.File +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.CompressionKind import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.Row @@ -216,7 +217,8 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { // Following codec is supported in hive-0.13.1, ignore it now ignore("Other Compression options for writing to an Orcfile - 0.13.1 and above") { - TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "SNAPPY") + val conf = TestHive.sparkContext.hadoopConfiguration + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.name(), "SNAPPY") var tempDir = getTempFilePath("orcTest").getCanonicalPath val rdd = sparkContext.parallelize(1 to 100) .map(i => TestRDDEntry(i, s"val_$i")) @@ -225,14 +227,14 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { assert(actualCodec == CompressionKind.SNAPPY) Utils.deleteRecursively(new File(tempDir)) - TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "NONE") + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.name(), "NONE") tempDir = getTempFilePath("orcTest").getCanonicalPath rdd.toDF().saveAsOrcFile(tempDir) actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression assert(actualCodec == CompressionKind.NONE) Utils.deleteRecursively(new File(tempDir)) - TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "LZO") + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.name(), "LZO") tempDir = getTempFilePath("orcTest").getCanonicalPath rdd.toDF().saveAsOrcFile(tempDir) actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression From d4afeed86f2094df8cb6ba509f5a2da22c3bf02b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 16 May 2015 22:19:06 +0800 Subject: [PATCH 11/12] Addresses comments - Migrates to the new DataFrame reader/writer API - Merges HadoopTypeConverter into HiveInspectors - Refactors test suites --- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../org/apache/spark/sql/sources/ddl.scala | 18 +- .../spark/sql/hive/HiveInspectors.scala | 40 ++- .../sql/hive/orc/HadoopTypeConverter.scala | 61 ---- .../spark/sql/hive/orc/OrcRelation.scala | 24 +- .../apache/spark/sql/hive/orc/package.scala | 74 ----- .../spark/sql/hive/orc/NewOrcQuerySuite.scala | 9 +- .../hive/orc/OrcHadoopFsRelationSuite.scala | 4 +- .../hive/orc/OrcPartitionDiscoverySuite.scala | 32 +-- .../spark/sql/hive/orc/OrcQuerySuite.scala | 260 ++++++++---------- .../spark/sql/hive/orc/OrcSourceSuite.scala | 174 ++++-------- 11 files changed, 249 insertions(+), 449 deletions(-) delete 
mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b1fc18ac3cb54..9f42f0f1f4398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -55,7 +55,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def mode(saveMode: String): DataFrameWriter = { - saveMode.toLowerCase match { + this.mode = saveMode.toLowerCase match { case "overwrite" => SaveMode.Overwrite case "append" => SaveMode.Append case "ignore" => SaveMode.Ignore diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 37a569db311ea..a13ab74852ff3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -188,18 +188,20 @@ private[sql] class DDLParser( private[sql] object ResolvedDataSource { private val builtinSources = Map( - "jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource], - "json" -> classOf[org.apache.spark.sql.json.DefaultSource], - "parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource] + "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource", + "json" -> "org.apache.spark.sql.json.DefaultSource", + "parquet" -> "org.apache.spark.sql.parquet.DefaultSource", + "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource" ) /** Given a provider name, look up the data source class definition. */ def lookupDataSource(provider: String): Class[_] = { + val loader = Utils.getContextOrSparkClassLoader + if (builtinSources.contains(provider)) { - return builtinSources(provider) + return loader.loadClass(builtinSources(provider)) } - val loader = Utils.getContextOrSparkClassLoader try { loader.loadClass(provider) } catch { @@ -208,7 +210,11 @@ private[sql] object ResolvedDataSource { loader.loadClass(provider + ".DefaultSource") } catch { case cnf: java.lang.ClassNotFoundException => - sys.error(s"Failed to load class for data source: $provider") + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + sys.error("The ORC data source must be used with Hive support enabled.") + } else { + sys.error(s"Failed to load class for data source: $provider") + } } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 7c7666f6e4b7c..0a694c70e4e5c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} @@ -122,7 +122,7 @@ import scala.collection.JavaConversions._ * even a normal java object (POJO) * UnionObjectInspector: (tag: Int, object data) (TODO: not supported by SparkSQL yet) * - * 3) ConstantObjectInspector: + 
* 3) ConstantObjectInspector: * Constant object inspector can be either primitive type or Complex type, and it bundles a * constant value as its property, usually the value is created when the constant object inspector * constructed. @@ -133,7 +133,7 @@ import scala.collection.JavaConversions._ } }}} * Hive provides 3 built-in constant object inspectors: - * Primitive Object Inspectors: + * Primitive Object Inspectors: * WritableConstantStringObjectInspector * WritableConstantHiveVarcharObjectInspector * WritableConstantHiveDecimalObjectInspector @@ -147,9 +147,9 @@ import scala.collection.JavaConversions._ * WritableConstantByteObjectInspector * WritableConstantBinaryObjectInspector * WritableConstantDateObjectInspector - * Map Object Inspector: + * Map Object Inspector: * StandardConstantMapObjectInspector - * List Object Inspector: + * List Object Inspector: * StandardConstantListObjectInspector]] * Struct Object Inspector: Hive doesn't provide the built-in constant object inspector for Struct * Union Object Inspector: Hive doesn't provide the built-in constant object inspector for Union @@ -250,9 +250,9 @@ private[hive] trait HiveInspectors { poi.getWritableConstantValue.getHiveDecimal) case poi: WritableConstantTimestampObjectInspector => poi.getWritableConstantValue.getTimestamp.clone() - case poi: WritableConstantIntObjectInspector => + case poi: WritableConstantIntObjectInspector => poi.getWritableConstantValue.get() - case poi: WritableConstantDoubleObjectInspector => + case poi: WritableConstantDoubleObjectInspector => poi.getWritableConstantValue.get() case poi: WritableConstantBooleanObjectInspector => poi.getWritableConstantValue.get() @@ -306,7 +306,7 @@ private[hive] trait HiveInspectors { // In order to keep backward-compatible, we have to copy the // bytes with old apis val bw = x.getPrimitiveWritableObject(data) - val result = new Array[Byte](bw.getLength()) + val result = new Array[Byte](bw.getLength()) System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength()) result case x: DateObjectInspector if x.preferWritable() => @@ -394,6 +394,30 @@ private[hive] trait HiveInspectors { identity[Any] } + /** + * Builds specific unwrappers ahead of time according to object inspector + * types to avoid pattern matching and branching costs per row. 
+ */ + def unwrapperFor(field: HiveStructField): (Any, MutableRow, Int) => Unit = + field.getFieldObjectInspector match { + case oi: BooleanObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + case oi: ByteObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + case oi: ShortObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + case oi: IntObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + case oi: LongObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + case oi: FloatObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + case oi: DoubleObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi => + (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) + } + /** * Converts native catalyst types to the types expected by Hive * @param a the value to be wrapped diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala deleted file mode 100644 index b5b5e56079cc3..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.orc - -import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ - -import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.hive.HiveInspectors - -/** - * We can consolidate TableReader.unwrappers and HiveInspectors.wrapperFor to use - * this class. - * - */ -private[hive] object HadoopTypeConverter extends HiveInspectors { - /** - * Builds specific unwrappers ahead of time according to object inspector - * types to avoid pattern matching and branching costs per row. 
- */ - def unwrappers(fieldRefs: Seq[StructField]): Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { - _.getFieldObjectInspector match { - case oi: BooleanObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) - case oi: ByteObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) - case oi: ShortObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) - case oi: IntObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) - case oi: LongObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) - case oi: FloatObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) - case oi: DoubleObjectInspector => - (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) - case oi => - (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) - } - } - - /** - * Wraps with Hive types based on object inspector. - */ - def wrappers(oi: ObjectInspector): Any => Any = wrapperFor(oi) -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 3e3c8a9e619d5..9708199f07349 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -34,7 +34,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveShim} +import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} @@ -50,6 +50,10 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider { schema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation = { + assert( + sqlContext.isInstanceOf[HiveContext], + "The ORC data source can only be used with HiveContext.") + val partitionSpec = partitionColumns.map(PartitionSpec(_, Seq.empty[Partition])) OrcRelation(paths, parameters, schema, partitionSpec)(sqlContext) } @@ -59,7 +63,7 @@ private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil { + extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { private val serializer = { val table = new Properties() @@ -89,7 +93,7 @@ private[orc] class OrcOutputWriter( // Used to convert Catalyst values into Hadoop `Writable`s. private val wrappers = structOI.getAllStructFieldRefs.map { ref => - HadoopTypeConverter.wrappers(ref.getFieldObjectInspector) + wrapperFor(ref.getFieldObjectInspector) }.toArray // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. 
We use this @@ -190,7 +194,10 @@ private[orc] case class OrcTableScan( attributes: Seq[Attribute], @transient relation: OrcRelation, filters: Array[Filter], - inputPaths: Array[String]) extends Logging { + inputPaths: Array[String]) + extends Logging + with HiveInspectors { + @transient private val sqlContext = relation.sqlContext private def addColumnIds( @@ -215,7 +222,7 @@ private[orc] case class OrcTableScan( case (attr, ordinal) => soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal }.unzip - val unwrappers = HadoopTypeConverter.unwrappers(fieldRefs) + val unwrappers = fieldRefs.map(unwrapperFor) // Map each tuple to a row object iterator.map { value => val raw = deserializer.deserialize(value) @@ -240,7 +247,7 @@ private[orc] case class OrcTableScan( // Tries to push down filters if ORC filter push-down is enabled if (sqlContext.conf.orcFilterPushDown) { OrcFilters.createFilter(filters).foreach { f => - conf.set(SARG_PUSHDOWN, f.toKryo) + conf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo) conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) } } @@ -276,3 +283,8 @@ private[orc] case class OrcTableScan( } } } + +private[orc] object OrcTableScan { + // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. + private[orc] val SARG_PUSHDOWN = "sarg.pushdown" +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala deleted file mode 100644 index ad0f65442b914..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{DataFrame, SaveMode} - -package object orc { - /** - * ::Experimental:: - * - * Extra ORC file loading functionality on [[HiveContext]] through implicit conversion. - * - * @since 1.4.0 - */ - @Experimental - implicit class OrcContext(sqlContext: HiveContext) { - /** - * ::Experimental:: - * - * Loads specified Parquet files, returning the result as a [[DataFrame]]. - * - * @since 1.4.0 - */ - @Experimental - @scala.annotation.varargs - def orcFile(paths: String*): DataFrame = { - val orcRelation = OrcRelation(paths.toArray, Map.empty)(sqlContext) - sqlContext.baseRelationToDataFrame(orcRelation) - } - } - - /** - * ::Experimental:: - * - * Extra ORC file writing functionality on [[DataFrame]] through implicit conversion - * - * @since 1.4.0 - */ - @Experimental - implicit class OrcDataFrame(dataFrame: DataFrame) { - /** - * ::Experimental:: - * - * Saves the contents of this [[DataFrame]] as an ORC file, preserving the schema. 
Files that - * are written out using this method can be read back in as a [[DataFrame]] using - * [[OrcContext.orcFile()]]. - * - * @since 1.4.0 - */ - @Experimental - def saveAsOrcFile(path: String, mode: SaveMode = SaveMode.Overwrite): Unit = { - dataFrame.save(path, source = classOf[DefaultSource].getCanonicalName, mode) - } - } - - // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. - private[orc] val SARG_PUSHDOWN = "sarg.pushdown" -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala index 7e326de1335e0..ad2fad05188de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala @@ -41,7 +41,7 @@ private[sql] trait OrcTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).toDF().saveAsOrcFile(file.getCanonicalPath) + sparkContext.parallelize(data).toDF().write.format("orc").save(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -53,8 +53,7 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - import org.apache.spark.sql.hive.orc.OrcContext - withOrcFile(data)(path => f(hiveContext.orcFile(path))) + withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path))) } /** @@ -73,12 +72,12 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.orc", SaveMode.Overwrite) + data.toDF().write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) } protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.save(path.getCanonicalPath, "org.apache.spark.sql.orc", SaveMode.Overwrite) + df.write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 90812b03fd2e6..080af5bb23c16 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -40,7 +40,9 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { sparkContext .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") - .saveAsOrcFile(partitionDir.toString) + .write + .format("orc") + .save(partitionDir.toString) } val dataSchemaWithPartition = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 55d8b8c71d9ef..88c99e35260d9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -48,13 +48,13 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().saveAsOrcFile(path.getCanonicalPath) + 
data.toDF().write.format("orc").mode("overwrite").save(path.getCanonicalPath) } def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.saveAsOrcFile(path.getCanonicalPath) + df.write.format("orc").mode("overwrite").save(path.getCanonicalPath) } protected def withTempTable(tableName: String)(f: => Unit): Unit = { @@ -89,7 +89,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - TestHive.orcFile(base.getCanonicalPath).registerTempTable("t") + read.format("orc").load(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -136,7 +136,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - TestHive.orcFile(base.getCanonicalPath).registerTempTable("t") + read.format("orc").load(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -185,13 +185,11 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val orcRelation = load( - "org.apache.spark.sql.hive.orc.DefaultSource", - Map( - "path" -> base.getCanonicalPath, - ConfVars.DEFAULTPARTITIONNAME.varname -> defaultPartitionName)) - - orcRelation.registerTempTable("t") + read + .format("orc") + .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + .load(base.getCanonicalPath) + .registerTempTable("t") withTempTable("t") { checkAnswer( @@ -230,13 +228,11 @@ class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with Before makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val orcRelation = load( - "org.apache.spark.sql.hive.orc.DefaultSource", - Map( - "path" -> base.getCanonicalPath, - ConfVars.DEFAULTPARTITIONNAME.varname -> defaultPartitionName)) - - orcRelation.registerTempTable("t") + read + .format("orc") + .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + .load(base.getCanonicalPath) + .registerTempTable("t") withTempTable("t") { checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index abc4c92d91da8..338ed7add1995 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -21,43 +21,13 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.CompressionKind +import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.util.Utils -import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} - -case class TestRDDEntry(key: Int, value: String) - -case class NullReflectData( - intField: java.lang.Integer, - longField: java.lang.Long, - floatField: java.lang.Float, - doubleField: java.lang.Double, - booleanField: java.lang.Boolean) - -case class OptionalReflectData( - intField: Option[Int], - longField: Option[Long], - floatField: Option[Float], - doubleField: Option[Double], - booleanField: Option[Boolean]) - -case class Nested(i: 
Int, s: String) - -case class Data(array: Seq[Int], nested: Nested) - -case class AllDataTypes( - stringField: String, - intField: Int, - longField: Long, - floatField: Float, - doubleField: Double, - shortField: Short, - byteField: Byte, - booleanField: Boolean) case class AllDataTypesWithNonPrimitiveType( stringField: String, @@ -72,7 +42,7 @@ case class AllDataTypesWithNonPrimitiveType( arrayContainsNull: Seq[Option[Int]], map: Map[Int, Long], mapValueContainsNull: Map[Int, Option[Long]], - data: Data) + data: (Seq[Int], (Int, String))) case class BinaryData(binaryData: Array[Byte]) @@ -80,7 +50,10 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) -class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { +class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll with OrcTest { + override val sqlContext = TestHive + + import TestHive.read def getTempFilePath(prefix: String, suffix: String = ""): File = { val tempFile = File.createTempFile(prefix, suffix) @@ -88,157 +61,146 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { tempFile } - test("Read/Write All Types") { - val tempDir = getTempFilePath("orcTest").getCanonicalPath - val range = (0 to 255) - val data = sparkContext.parallelize(range) - .map(x => - AllDataTypes(s"$x", x, x.toLong, x.toFloat,x.toDouble, x.toShort, x.toByte, x % 2 == 0)) - data.toDF().saveAsOrcFile(tempDir) + test("Read/write All Types") { + val data = (0 to 255).map { i => + (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0) + } + + withOrcFile(data) { file => checkAnswer( - TestHive.orcFile(tempDir), - data.toDF().collect().toSeq) - Utils.deleteRecursively(new File(tempDir)) + read.format("orc").load(file), + data.toDF().collect()) } + } - test("read/write binary data") { - val tempDir = getTempFilePath("orcTest").getCanonicalPath - sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil) - .toDF().saveAsOrcFile(tempDir) - TestHive.orcFile(tempDir) - .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) - .collect().toSeq == Seq("test") - Utils.deleteRecursively(new File(tempDir)) + test("Read/write binary data") { + withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file => + val bytes = read.format("orc").load(file).head().getAs[Array[Byte]](0) + assert(new String(bytes, "utf8") === "test") } + } - test("Read/Write All Types with non-primitive type") { - val tempDir = getTempFilePath("orcTest").getCanonicalPath - val range = 0 to 255 - val data = sparkContext.parallelize(range).map { x => - AllDataTypesWithNonPrimitiveType( - s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - 0 until x, - (0 until x).map(Option(_).filter(_ % 3 == 0)), - (0 until x).map(i => i -> i.toLong).toMap, - (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), - Data(0 until x, Nested(x, s"$x"))) - } - data.toDF().saveAsOrcFile(tempDir) + test("Read/write all types with non-primitive type") { + val data = (0 to 255).map { i => + AllDataTypesWithNonPrimitiveType( + s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, + 0 until i, + (0 until i).map(Option(_).filter(_ % 3 == 0)), + (0 until i).map(i => i -> i.toLong).toMap, + (0 until i).map(i => i -> Option(i.toLong)).toMap + (i -> None), + (0 until i, (i, s"$i"))) + } + withOrcFile(data) { file => checkAnswer( - TestHive.orcFile(tempDir), - data.toDF().collect().toSeq) - 
Utils.deleteRecursively(new File(tempDir)) + read.format("orc").load(file), + data.toDF().collect()) } + } - test("Creating case class RDD table") { - sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - .toDF().registerTempTable("tmp") - val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0)) - var counter = 1 - rdd.foreach { - // '===' does not like string comparison? - row: Row => { - assert(row.getString(1).equals(s"val_$counter"), - s"row $counter value ${row.getString(1)} does not match val_$counter") - counter = counter + 1 - } - } + test("Creating case class RDD table") { + val data = (1 to 100).map(i => (i, s"val_$i")) + sparkContext.parallelize(data).toDF().registerTempTable("t") + withTempTable("t") { + checkAnswer(sql("SELECT * FROM t"), data.toDF().collect()) } + } - test("Simple selection form orc table") { - val tempDir = getTempFilePath("orcTest").getCanonicalPath - val data = sparkContext.parallelize((1 to 10)) - .map(i => Person(s"name_$i", i, (0 until 2).map{ m=> - Contact(s"contact_$m", s"phone_$m") })) - data.toDF().saveAsOrcFile(tempDir) - val f = TestHive.orcFile(tempDir) - f.registerTempTable("tmp") + test("Simple selection form ORC table") { + val data = (1 to 10).map { i => + Person(s"name_$i", i, (0 to 1).map { m => Contact(s"contact_$m", s"phone_$m") }) + } + withOrcTable(data, "t") { // ppd: // leaf-0 = (LESS_THAN_EQUALS age 5) // expr = leaf-0 - var rdd = sql("SELECT name FROM tmp where age <= 5") - assert(rdd.count() == 5) + assert(sql("SELECT name FROM t WHERE age <= 5").count() === 5) // ppd: // leaf-0 = (LESS_THAN_EQUALS age 5) // expr = (not leaf-0) - rdd = sql("SELECT name, contacts FROM tmp where age > 5") - assert(rdd.count() == 5) - var contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) - assert(contacts.count() == 10) + assertResult(10) { + sql("SELECT name, contacts FROM t where age > 5") + .flatMap(_.getAs[Seq[_]]("contacts")) + .count() + } // ppd: // leaf-0 = (LESS_THAN_EQUALS age 5) // leaf-1 = (LESS_THAN age 8) // expr = (and (not leaf-0) leaf-1) - rdd = sql("SELECT name, contacts FROM tmp where age > 5 and age < 8") - assert(rdd.count() == 2) - contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) - assert(contacts.count() == 4) + { + val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8") + assert(df.count() === 2) + assertResult(4) { + df.flatMap(_.getAs[Seq[_]]("contacts")).count() + } + } // ppd: // leaf-0 = (LESS_THAN age 2) // leaf-1 = (LESS_THAN_EQUALS age 8) // expr = (or leaf-0 (not leaf-1)) - rdd = sql("SELECT name, contacts FROM tmp where age < 2 or age > 8") - assert(rdd.count() == 3) - contacts = rdd.flatMap(t=>t(1).asInstanceOf[Seq[_]]) - assert(contacts.count() == 6) + { + val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8") + assert(df.count() === 3) + assertResult(6) { + df.flatMap(_.getAs[Seq[_]]("contacts")).count() + } + } + } + } + test("save and load case class RDD with `None`s as orc") { + val data = ( + None: Option[Int], + None: Option[Long], + None: Option[Float], + None: Option[Double], + None: Option[Boolean] + ) :: Nil - Utils.deleteRecursively(new File(tempDir)) + withOrcFile(data) { file => + checkAnswer( + read.format("orc").load(file), + Row(Seq.fill(5)(null): _*)) } + } - test("save and load case class RDD with Nones as orc") { - val data = OptionalReflectData(None, None, None, None, None) - val rdd = sparkContext.parallelize(data :: Nil) - val tempDir = getTempFilePath("orcTest").getCanonicalPath - rdd.toDF().saveAsOrcFile(tempDir) - val 
readFile = TestHive.orcFile(tempDir) - val rdd_saved = readFile.collect() - assert(rdd_saved(0).toSeq === Seq.fill(5)(null)) - Utils.deleteRecursively(new File(tempDir)) + // We only support zlib in Hive 0.12.0 now + test("Default compression options for writing to an ORC file") { + withOrcFile((1 to 100).map(i => (i, s"val_$i"))) { file => + assertResult(CompressionKind.ZLIB) { + OrcFileOperator.getFileReader(file).getCompression + } } + } + + // Following codec is supported in hive-0.13.1, ignore it now + ignore("Other compression options for writing to an ORC file - 0.13.1 and above") { + val data = (1 to 100).map(i => (i, s"val_$i")) + val conf = sparkContext.hadoopConfiguration - // We only support zlib in hive0.12.0 now - test("Default Compression options for writing to an Orcfile") { - // TODO: support other compress codec - val tempDir = getTempFilePath("orcTest").getCanonicalPath - val rdd = sparkContext.parallelize(1 to 100) - .map(i => TestRDDEntry(i, s"val_$i")) - rdd.toDF().saveAsOrcFile(tempDir) - val actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression - assert(actualCodec == CompressionKind.ZLIB) - Utils.deleteRecursively(new File(tempDir)) + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "SNAPPY") + withOrcFile(data) { file => + assertResult(CompressionKind.SNAPPY) { + OrcFileOperator.getFileReader(file).getCompression + } } - // Following codec is supported in hive-0.13.1, ignore it now - ignore("Other Compression options for writing to an Orcfile - 0.13.1 and above") { - val conf = TestHive.sparkContext.hadoopConfiguration - conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.name(), "SNAPPY") - var tempDir = getTempFilePath("orcTest").getCanonicalPath - val rdd = sparkContext.parallelize(1 to 100) - .map(i => TestRDDEntry(i, s"val_$i")) - rdd.toDF().saveAsOrcFile(tempDir) - var actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression - assert(actualCodec == CompressionKind.SNAPPY) - Utils.deleteRecursively(new File(tempDir)) - - conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.name(), "NONE") - tempDir = getTempFilePath("orcTest").getCanonicalPath - rdd.toDF().saveAsOrcFile(tempDir) - actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression - assert(actualCodec == CompressionKind.NONE) - Utils.deleteRecursively(new File(tempDir)) - - conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.name(), "LZO") - tempDir = getTempFilePath("orcTest").getCanonicalPath - rdd.toDF().saveAsOrcFile(tempDir) - actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression - assert(actualCodec == CompressionKind.LZO) - Utils.deleteRecursively(new File(tempDir)) + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "NONE") + withOrcFile(data) { file => + assertResult(CompressionKind.NONE) { + OrcFileOperator.getFileReader(file).getCompression + } } + + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "LZO") + withOrcFile(data) { file => + assertResult(CompressionKind.LZO) { + OrcFileOperator.getFileReader(file).getCompression + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index f86750bcfb6d4..82e08caf46457 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.hive.orc import java.io.File + import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{Row, 
QueryTest} import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.{QueryTest, Row} case class OrcData(intField: Int, stringField: String) @@ -42,25 +43,25 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { orcTableDir.mkdir() import org.apache.spark.sql.hive.test.TestHive.implicits._ - (sparkContext + sparkContext .makeRDD(1 to 10) - .map(i => OrcData(i, s"part-$i"))) - .toDF.registerTempTable(s"orc_temp_table") - - sql(s""" - create external table normal_orc - ( - intField INT, - stringField STRING - ) - STORED AS orc - location '${orcTableDir.getCanonicalPath}' - """) + .map(i => OrcData(i, s"part-$i")) + .toDF() + .registerTempTable(s"orc_temp_table") sql( - s"""insert into table normal_orc - select intField, stringField from orc_temp_table""") + s"""CREATE EXTERNAL TABLE normal_orc( + | intField INT, + | stringField STRING + |) + |STORED AS ORC + |LOCATION '${orcTableAsDir.getCanonicalPath}' + """.stripMargin) + sql( + s"""INSERT INTO TABLE normal_orc + |SELECT intField, stringField FROM orc_temp_table + """.stripMargin) } override def afterAll(): Unit = { @@ -73,41 +74,15 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM normal_orc_source"), - Row(1, "part-1") :: - Row(2, "part-2") :: - Row(3, "part-3") :: - Row(4, "part-4") :: - Row(5, "part-5") :: - Row(6, "part-6") :: - Row(7, "part-7") :: - Row(8, "part-8") :: - Row(9, "part-9") :: - Row(10, "part-10") :: Nil - ) + (1 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT * FROM normal_orc_source where intField > 5"), - Row(6, "part-6") :: - Row(7, "part-7") :: - Row(8, "part-8") :: - Row(9, "part-9") :: - Row(10, "part-10") :: Nil - ) + (6 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - sql("SELECT count(intField), stringField FROM normal_orc_source group by stringField"), - Row(1, "part-1") :: - Row(1, "part-2") :: - Row(1, "part-3") :: - Row(1, "part-4") :: - Row(1, "part-5") :: - Row(1, "part-6") :: - Row(1, "part-7") :: - Row(1, "part-8") :: - Row(1, "part-9") :: - Row(1, "part-10") :: Nil - ) - + sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + (1 to 10).map(i => Row(1, s"part-$i"))) } test("create temporary orc table as") { @@ -115,76 +90,36 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM normal_orc_source"), - Row(1, "part-1") :: - Row(2, "part-2") :: - Row(3, "part-3") :: - Row(4, "part-4") :: - Row(5, "part-5") :: - Row(6, "part-6") :: - Row(7, "part-7") :: - Row(8, "part-8") :: - Row(9, "part-9") :: - Row(10, "part-10") :: Nil - ) + (1 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - sql("SELECT * FROM normal_orc_source where intField > 5"), - Row(6, "part-6") :: - Row(7, "part-7") :: - Row(8, "part-8") :: - Row(9, "part-9") :: - Row(10, "part-10") :: Nil - ) + sql("SELECT * FROM normal_orc_source WHERE intField > 5"), + (6 to 10).map(i => Row(i, s"part-$i"))) checkAnswer( - sql("SELECT count(intField), stringField FROM normal_orc_source group by stringField"), - Row(1, "part-1") :: - Row(1, "part-2") :: - Row(1, "part-3") :: - Row(1, "part-4") :: - Row(1, "part-5") :: - Row(1, "part-6") :: - Row(1, "part-7") :: - Row(1, "part-8") :: - Row(1, "part-9") :: - Row(1, "part-10") :: Nil - ) - + sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + (1 to 10).map(i => Row(1, s"part-$i"))) } test("appending insert") { - sql("insert into table normal_orc_source select * from orc_temp_table 
where intField > 5") + sql("INSERT INTO TABLE normal_orc_source SELECT * FROM orc_temp_table WHERE intField > 5") + checkAnswer( - sql("select * from normal_orc_source"), - Row(1, "part-1") :: - Row(2, "part-2") :: - Row(3, "part-3") :: - Row(4, "part-4") :: - Row(5, "part-5") :: - Row(6, "part-6") :: - Row(6, "part-6") :: - Row(7, "part-7") :: - Row(7, "part-7") :: - Row(8, "part-8") :: - Row(8, "part-8") :: - Row(9, "part-9") :: - Row(9, "part-9") :: - Row(10, "part-10") :: - Row(10, "part-10") :: Nil - ) + sql("SELECT * FROM normal_orc_source"), + (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i => + Seq.fill(2)(Row(i, s"part-$i")) + }) } test("overwrite insert") { - sql("insert overwrite table normal_orc_as_source select * " + - "from orc_temp_table where intField > 5") + sql( + """INSERT OVERWRITE TABLE normal_orc_as_source + |SELECT * FROM orc_temp_table WHERE intField > 5 + """.stripMargin) + checkAnswer( - sql("select * from normal_orc_as_source"), - Row(6, "part-6") :: - Row(7, "part-7") :: - Row(8, "part-8") :: - Row(9, "part-9") :: - Row(10, "part-10") :: Nil - ) + sql("SELECT * FROM normal_orc_as_source"), + (6 to 10).map(i => Row(i, s"part-$i"))) } } @@ -192,21 +127,20 @@ class OrcSourceSuite extends OrcSuite { override def beforeAll(): Unit = { super.beforeAll() - sql( s""" - create temporary table normal_orc_source - USING org.apache.spark.sql.hive.orc - OPTIONS ( - path '${new File(orcTableDir.getAbsolutePath).getCanonicalPath}' - ) - """) - - sql( s""" - create temporary table normal_orc_as_source - USING org.apache.spark.sql.hive.orc - OPTIONS ( - path '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' - ) - as select * from orc_temp_table - """) + sql( + s"""CREATE TEMPORARY TABLE normal_orc_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' + |) + """.stripMargin) + + sql( + s"""CREATE TEMPORARY TABLE normal_orc_as_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' + |) + """.stripMargin) } } From 55ecd9641838ce4af364005917dcd1d0ffadc3b6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 17 May 2015 12:32:15 +0800 Subject: [PATCH 12/12] Reorganizes ORC test suites --- .../spark/sql/hive/orc/NewOrcQuerySuite.scala | 176 ------------------ .../spark/sql/hive/orc/OrcQuerySuite.scala | 88 +++++++++ .../apache/spark/sql/hive/orc/OrcTest.scala | 82 ++++++++ 3 files changed, 170 insertions(+), 176 deletions(-) delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala deleted file mode 100644 index ad2fad05188de..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.orc - -import java.io.File - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql._ - -private[sql] trait OrcTest extends SQLTestUtils { - protected def hiveContext = sqlContext.asInstanceOf[HiveContext] - - import sqlContext.sparkContext - import sqlContext.implicits._ - - /** - * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` - * returns. - */ - protected def withOrcFile[T <: Product: ClassTag: TypeTag] - (data: Seq[T]) - (f: String => Unit): Unit = { - withTempPath { file => - sparkContext.parallelize(data).toDF().write.format("orc").save(file.getCanonicalPath) - f(file.getCanonicalPath) - } - } - - /** - * Writes `data` to a Orc file and reads it back as a [[DataFrame]], - * which is then passed to `f`. The Orc file will be deleted after `f` returns. - */ - protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] - (data: Seq[T]) - (f: DataFrame => Unit): Unit = { - withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path))) - } - - /** - * Writes `data` to a Orc file, reads it back as a [[DataFrame]] and registers it as a - * temporary table named `tableName`, then call `f`. The temporary table together with the - * Orc file will be dropped/deleted after `f` returns. 
- */ - protected def withOrcTable[T <: Product: ClassTag: TypeTag] - (data: Seq[T], tableName: String) - (f: => Unit): Unit = { - withOrcDataFrame(data) { df => - hiveContext.registerDataFrameAsTable(df, tableName) - withTempTable(tableName)(f) - } - } - - protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( - data: Seq[T], path: File): Unit = { - data.toDF().write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) - } - - protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( - df: DataFrame, path: File): Unit = { - df.write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) - } -} - -class NewOrcQuerySuite extends QueryTest with OrcTest { - override val sqlContext: SQLContext = TestHive - - import sqlContext._ - - test("simple select queries") { - withOrcTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer( - sql("SELECT `_1` FROM t where t.`_1` > 5"), - (6 until 10).map(Row.apply(_))) - - checkAnswer( - sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"), - (0 until 5).map(Row.apply(_))) - } - } - - test("appending") { - val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") - withOrcTable(data, "t") { - sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) - } - catalog.unregisterTable(Seq("tmp")) - } - - test("overwriting") { - val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") - withOrcTable(data, "t") { - sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), data.map(Row.fromTuple)) - } - catalog.unregisterTable(Seq("tmp")) - } - - test("self-join") { - // 4 rows, cells of column 1 of row 2 and row 4 are null - val data = (1 to 4).map { i => - val maybeInt = if (i % 2 == 0) None else Some(i) - (maybeInt, i.toString) - } - - withOrcTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`") - val queryOutput = selfJoin.queryExecution.analyzed.output - - assertResult(4, "Field count mismatches")(queryOutput.size) - assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { - queryOutput.filter(_.name == "_1").map(_.exprId).size - } - - checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) - } - } - - test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) - withOrcTable(data, "t") { - checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map { - case Tuple1((_, Seq(string))) => Row(string) - }) - } - } - - test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) - withOrcTable(data, "t") { - checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map { - case Tuple1(Seq((_, string))) => Row(string) - }) - } - } - - test("columns only referenced by pushed down filters should remain") { - withOrcTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_))) - } - } - - test("SPARK-5309 strings stored using dictionary compression in orc") { - withOrcTable((0 until 1000).map(i => ("same", "run_" + i / 100, 1)), "t") { - checkAnswer( - sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"), - (0 until 10).map(i => Row("same", "run_" + i, 100))) - - checkAnswer( - sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"), - List(Row("same", "run_5", 100))) - } - } -} diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 338ed7add1995..cdd6e705f4a2c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -203,4 +203,92 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll w } } } + + test("simple select queries") { + withOrcTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer( + sql("SELECT `_1` FROM t where t.`_1` > 5"), + (6 until 10).map(Row.apply(_))) + + checkAnswer( + sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"), + (0 until 5).map(Row.apply(_))) + } + } + + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withOrcTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) + } + + test("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withOrcTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), data.map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) + } + + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } + + withOrcTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`") + val queryOutput = selfJoin.queryExecution.analyzed.output + + assertResult(4, "Field count mismatches")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size + } + + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + } + } + + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + withOrcTable(data, "t") { + checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) + } + } + + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + withOrcTable(data, "t") { + checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) + } + } + + test("columns only referenced by pushed down filters should remain") { + withOrcTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_))) + } + } + + test("SPARK-5309 strings stored using dictionary compression in orc") { + withOrcTable((0 until 1000).map(i => ("same", "run_" + i / 100, 1)), "t") { + checkAnswer( + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer( + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"), + List(Row("same", "run_5", 100))) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala new file mode 100644 index 0000000000000..750f0b04aaa87 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -0,0 +1,82 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql._ + +private[sql] trait OrcTest extends SQLTestUtils { + protected def hiveContext = sqlContext.asInstanceOf[HiveContext] + + import sqlContext.sparkContext + import sqlContext.implicits._ + + /** + * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withOrcFile[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: String => Unit): Unit = { + withTempPath { file => + sparkContext.parallelize(data).toDF().write.format("orc").save(file.getCanonicalPath) + f(file.getCanonicalPath) + } + } + + /** + * Writes `data` to a Orc file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Orc file will be deleted after `f` returns. + */ + protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: DataFrame => Unit): Unit = { + withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path))) + } + + /** + * Writes `data` to a Orc file, reads it back as a [[DataFrame]] and registers it as a + * temporary table named `tableName`, then call `f`. The temporary table together with the + * Orc file will be dropped/deleted after `f` returns. + */ + protected def withOrcTable[T <: Product: ClassTag: TypeTag] + (data: Seq[T], tableName: String) + (f: => Unit): Unit = { + withOrcDataFrame(data) { df => + hiveContext.registerDataFrameAsTable(df, tableName) + withTempTable(tableName)(f) + } + } + + protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) + } + + protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) + } +}
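
As a usage reference, here is a minimal sketch of how the OrcTest helpers introduced above are meant to compose; it is not taken from the patch, and the suite name and sample data are illustrative assumptions. withOrcTable layers withOrcDataFrame over withOrcFile, so both the temporary ORC file and the temporary table are cleaned up once the test body returns.

package org.apache.spark.sql.hive.orc

import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.{QueryTest, Row, SQLContext}

// Hypothetical suite built on the OrcTest trait from this patch; only the test
// name and data are invented, the helper calls mirror those used in OrcQuerySuite.
class ExampleOrcRoundTripSuite extends QueryTest with OrcTest {
  override val sqlContext: SQLContext = TestHive

  import sqlContext._

  test("round-trips simple tuples through ORC") {
    val data = (1 to 5).map(i => (i, s"val_$i"))

    // Writes `data` to a temporary ORC file, registers it as temp table `t`,
    // runs the body, then drops the table and deletes the file.
    withOrcTable(data, "t") {
      checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple))
    }
  }
}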