From 59dec91c3f10ecf4c70bc834b1575534cdb8f568 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 24 Feb 2016 15:06:46 -0800 Subject: [PATCH 01/10] [SPARK-13255][SQL] Update vectorized reader to directly return ColumnarBatch instead of InternalRows. Currently, the parquet reader returns rows one by one which is bad for performance. This patch updates the reader to directly return ColumnarBatches. This is only enabled with whole stage codegen, which is the only operator currently that is able to consume ColumnarBatches (instead of rows). The current implementation is a bit of a hack to get this to work and we should do more refactoring of these low level interfaces to make this work better. Results: TPCDS: Best/Avg Time(ms) Rate(M/s) Per Row(ns) --------------------------------------------------------------------------------- q55 (before) 8897 / 9265 12.9 77.2 q55 5486 / 5753 21.0 47.6 --- .../parquet/UnsafeRowParquetRecordReader.java | 29 ++++++++-- .../spark/sql/execution/ExistingRDD.scala | 55 ++++++++++++++++--- .../datasources/SqlNewHadoopRDD.scala | 8 ++- .../datasources/parquet/ParquetIOSuite.scala | 8 +-- 4 files changed, 83 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 57dbd7c2ff56f..7d768b165f833 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -37,7 +37,6 @@ import org.apache.parquet.schema.Type; import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; @@ -57,10 +56,14 @@ * * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch. * All of these can be handled efficiently and easily with codegen. + * + * This class can either return InternalRows or ColumnarBatches. With whole stage codegen + * enabled, this class returns ColumnarBatches which offers significant performance gains. + * TODO: make this always return ColumnarBatches. */ -public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase { +public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase { /** - * Batch of unsafe rows that we assemble and the current index we've returned. Everytime this + * Batch of unsafe rows that we assemble and the current index we've returned. Every time this * batch is used up (batchIdx == numBatched), we populated the batch. */ private UnsafeRow[] rows = new UnsafeRow[64]; @@ -115,11 +118,15 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas * code between the path that uses the MR decoders and the vectorized ones. * * TODOs: - * - Implement all the encodings to support vectorized. * - Implement v2 page formats (just make sure we create the correct decoders). */ private ColumnarBatch columnarBatch; + /** + * If true, this class returns batches instead of rows. + */ + private boolean returnColumnarBatch; + /** * The default config on whether columnarBatch should be offheap. 
*/ @@ -169,6 +176,8 @@ public void close() throws IOException { @Override public boolean nextKeyValue() throws IOException, InterruptedException { + if (returnColumnarBatch) return nextBatch(); + if (batchIdx >= numBatched) { if (vectorizedDecode()) { if (!nextBatch()) return false; @@ -181,7 +190,9 @@ public boolean nextKeyValue() throws IOException, InterruptedException { } @Override - public InternalRow getCurrentValue() throws IOException, InterruptedException { + public Object getCurrentValue() throws IOException, InterruptedException { + if (returnColumnarBatch) return columnarBatch; + if (vectorizedDecode()) { return columnarBatch.getRow(batchIdx - 1); } else { @@ -210,6 +221,14 @@ public ColumnarBatch resultBatch(MemoryMode memMode) { return columnarBatch; } + /** + * Can be called before any rows are returned to enable returning columnar batches directly. + */ + public void enableReturningBatches() { + assert(vectorizedDecode()); + returnColumnarBatch = true; + } + /** * Advances to the next batch of rows. Returns false if there are no more. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 2cbe3f2c94202..455a0e75361e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -149,14 +149,55 @@ private[sql] case class PhysicalRDD( ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) + + // The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this + // by looking at the first value of the RDD and then calling the function which will process + // the remaining. It is faster to return batches. + // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know + // here which path to use. Fix this. 
+ + val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" + + val scanBatches = ctx.freshName("processBatches") + ctx.addNewFunction(scanBatches, + s""" + | private void $scanBatches($columnarBatchClz batch) throws java.io.IOException { + | while (true) { + | int numRows = batch.numRows(); + | $numOutputRows.add(numRows); + | for (int i = 0; i < numRows; i++) { + | InternalRow $row = batch.getRow(i); + | ${columns.map(_.code).mkString("\n").trim} + | ${consume(ctx, columns).trim} + | } + | + | if (shouldStop()) return; + | if (!$input.hasNext()) break; + | batch = ($columnarBatchClz)$input.next(); + | } + | }""".stripMargin) + + val scanRows = ctx.freshName("processRows") + ctx.addNewFunction(scanRows, + s""" + | private void $scanRows(InternalRow $row) throws java.io.IOException { + | while (true) { + | $numOutputRows.add(1); + | ${columns.map(_.code).mkString("\n").trim} + | ${consume(ctx, columns).trim} + | if (shouldStop()) return; + | if (!$input.hasNext()) break; + | $row = (InternalRow)$input.next(); + | } + | }""".stripMargin) + s""" - | while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | $numOutputRows.add(1); - | ${columns.map(_.code).mkString("\n").trim} - | ${consume(ctx, columns).trim} - | if (shouldStop()) { - | return; + | if ($input.hasNext()) { + | Object firstValue = $input.next(); + | if (firstValue instanceof $columnarBatchClz) { + | $scanBatches(($columnarBatchClz)firstValue); + | } else { + | $scanRows((InternalRow)firstValue); | } | } """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index f4271d165c9bd..c4c7eccab6f69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -102,6 +102,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean protected val enableVectorizedParquetReader: Boolean = sqlContext.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean + protected val enableWholestageCodegen: Boolean = + sqlContext.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key).toBoolean override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) @@ -179,7 +181,11 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( parquetReader.close() } else { reader = parquetReader.asInstanceOf[RecordReader[Void, V]] - if (enableVectorizedParquetReader) parquetReader.resultBatch() + if (enableVectorizedParquetReader) { + parquetReader.resultBatch() + // Whole stage codegen (PhysicalRDD) is able to deal with batches directly + if (enableWholestageCodegen) parquetReader.enableReturningBatches(); + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index c85eeddc2c6d9..cf8a9fdd46fca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -37,7 +37,7 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql._ -import 
org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf @@ -683,7 +683,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { reader.initialize(file, null) val result = mutable.ArrayBuffer.empty[(Int, String)] while (reader.nextKeyValue()) { - val row = reader.getCurrentValue + val row = reader.getCurrentValue.asInstanceOf[InternalRow] val v = (row.getInt(0), row.getString(1)) result += v } @@ -700,7 +700,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { reader.initialize(file, ("_2" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String)] while (reader.nextKeyValue()) { - val row = reader.getCurrentValue + val row = reader.getCurrentValue.asInstanceOf[InternalRow] result += row.getString(0) } assert(data.map(_._2) == result) @@ -716,7 +716,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { reader.initialize(file, ("_2" :: "_1" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String, Int)] while (reader.nextKeyValue()) { - val row = reader.getCurrentValue + val row = reader.getCurrentValue.asInstanceOf[InternalRow] val v = (row.getString(0), row.getInt(1)) result += v } From 058556c1478bd05604ece49d26f5cac8ffa01638 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Mon, 29 Feb 2016 13:32:43 -0800 Subject: [PATCH 02/10] Rebase fixes. --- .../datasources/parquet/ParquetReadBenchmark.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 14dbdf34093e9..93265bae1e6b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -22,8 +22,9 @@ import scala.collection.JavaConverters._ import scala.util.Try import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.SQLContext import org.apache.spark.util.{Benchmark, Utils} /** @@ -101,7 +102,7 @@ object ParquetReadBenchmark { reader.initialize(p, ("id" :: Nil).asJava) while (reader.nextKeyValue()) { - val record = reader.getCurrentValue + val record = reader.getCurrentValue.asInstanceOf[InternalRow] if (!record.isNullAt(0)) sum += record.getInt(0) } reader.close() @@ -209,7 +210,7 @@ object ParquetReadBenchmark { val reader = new UnsafeRowParquetRecordReader reader.initialize(p, null) while (reader.nextKeyValue()) { - val record = reader.getCurrentValue + val record = reader.getCurrentValue.asInstanceOf[InternalRow] if (!record.isNullAt(0)) sum1 += record.getInt(0) if (!record.isNullAt(1)) sum2 += record.getUTF8String(1).numBytes() } From 233057656fa0b09e797f5d1292cad70cb3fd9b8d Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 2 Mar 2016 15:59:49 -0800 Subject: [PATCH 03/10] Fix partition columns. 
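
Partition columns never appear in the Parquet files themselves, so once the reader starts handing back ColumnarBatches the partition values have to be spliced into each batch once per physical partition instead of being joined row by row. A rough sketch of the intended flow (an illustration for this message only, not code taken from the diff below), using the ColumnVectorUtils.populate and ColumnarBatch.setColumn helpers this commit adds:

    // Illustration only: build the projected batch once per physical partition.
    // Data columns just reference the scanned vectors (no copy); partition
    // columns are filled once with the constant value for that partition.
    val projected = ColumnarBatch.allocate(StructType.fromAttributes(requiredColumns))
    requiredColumns.zipWithIndex.foreach { case (attr, i) =>
      val dataIdx = dataColumns.indexOf(attr)
      if (dataIdx >= 0) {
        projected.setColumn(i, inputBatch.column(dataIdx))
      } else {
        val partitionIdx = partitionColumnSchema.fieldIndex(attr.name)
        ColumnVectorUtils.populate(projected.column(i), partitionValues, partitionIdx)
      }
    }
    projected.setNumRows(inputBatch.numRows())
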
--- .../vectorized/ColumnVectorUtils.java | 54 ++++++++++++++ .../execution/vectorized/ColumnarBatch.java | 9 +++ .../datasources/DataSourceStrategy.scala | 71 ++++++++++++++++--- 3 files changed, 125 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 681ace3387139..e2c88d7b4e802 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -26,9 +26,11 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; /** * Utilities to help manipulate data associate with ColumnVectors. These should be used mostly @@ -36,6 +38,58 @@ * These utilities are mostly used to convert ColumnVectors into other formats. */ public class ColumnVectorUtils { + /** + * Populates the entire `col` with `row[fieldIdx]` + */ + public static void populate(ColumnVector col, InternalRow row, int fieldIdx) { + int capacity = col.capacity; + DataType t = col.dataType(); + + if (row.isNullAt(fieldIdx)) { + col.putNulls(0, capacity); + } else { + if (t == DataTypes.BooleanType) { + col.putBooleans(0, capacity, row.getBoolean(fieldIdx)); + } else if (t == DataTypes.ByteType) { + col.putBytes(0, capacity, row.getByte(fieldIdx)); + } else if (t == DataTypes.ShortType) { + col.putShorts(0, capacity, row.getShort(fieldIdx)); + } else if (t == DataTypes.IntegerType) { + col.putInts(0, capacity, row.getInt(fieldIdx)); + } else if (t == DataTypes.LongType) { + col.putLongs(0, capacity, row.getLong(fieldIdx)); + } else if (t == DataTypes.FloatType) { + col.putFloats(0, capacity, row.getFloat(fieldIdx)); + } else if (t == DataTypes.DoubleType) { + col.putDoubles(0, capacity, row.getDouble(fieldIdx)); + } else if (t == DataTypes.StringType) { + UTF8String v = row.getUTF8String(fieldIdx); + for (int i = 0; i < capacity; i++) { + col.putByteArray(i, v.getBytes()); + } + } else if (t instanceof DecimalType) { + DecimalType dt = (DecimalType)t; + Decimal d = row.getDecimal(fieldIdx, dt.precision(), dt.scale()); + if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) { + col.putLongs(0, capacity, d.toUnscaledLong()); + } else { + final BigInteger integer = d.toJavaBigDecimal().unscaledValue(); + byte[] bytes = integer.toByteArray(); + for (int i = 0; i < capacity; i++) { + col.putByteArray(i, bytes, 0, bytes.length); + } + } + } else if (t instanceof CalendarIntervalType) { + CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t); + col.getChildColumn(0).putInts(0, capacity, c.months); + col.getChildColumn(1).putLongs(0, capacity, c.microseconds); + } else if (t instanceof DateType) { + Date date = (Date)row.get(fieldIdx, t); + col.putInts(0, capacity, DateTimeUtils.fromJavaDate(date)); + } + } + } + /** * Returns the array data as the java primitive array. * For example, an array of IntegerType will return an int[]. 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 2a780588384ed..26184d4196e0c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -22,6 +22,7 @@ import org.apache.commons.lang.NotImplementedException; import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.Column; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericMutableRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -315,6 +316,14 @@ public int numValidRows() { */ public ColumnVector column(int ordinal) { return columns[ordinal]; } + /** + * Sets (replaces) the column at `ordinal` with column. This can be used to do very efficient + * projections. + */ + public void setColumn(int ordinal, ColumnVector column) { + columns[ordinal] = column; + } + /** * Returns the row in this batch at `rowId`. Returned row is reused across calls. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index ceb35107bf7d8..8c1261563c958 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -19,25 +19,26 @@ package org.apache.spark.sql.execution.datasources import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.ExecutedCommand +import org.apache.spark.sql.execution.vectorized.{ColumnVectorUtils, ColumnarBatch} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.BitSet +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, TaskContext} /** * A Strategy for planning scans over data sources defined using the sources API. @@ -220,6 +221,42 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { sparkPlan } + // Creates a ColumnarBatch that contains the values for `requiredColumns`. 
These columns can + // either come from `input` (columns scanned from the data source) or from the partitioning + // values (data from `partitionValues`). This is done *once* per physical partition. When + // the column is from `input`, it just references the same underlying column. When using + // partition columns, the column is populated once. + // TODO: there's probably a cleaner way to do this. + private def projectedColumnBatch( + input: ColumnarBatch, + requiredColumns: Seq[Attribute], + dataColumns: Seq[Attribute], + partitionColumnSchema: StructType, + partitionValues: InternalRow) : ColumnarBatch = { + val result = ColumnarBatch.allocate(StructType.fromAttributes(requiredColumns)) + var resultIdx = 0 + var inputIdx = 0 + + while (resultIdx < requiredColumns.length) { + val attr = requiredColumns(resultIdx) + if (inputIdx < dataColumns.length && requiredColumns(resultIdx) == dataColumns(inputIdx)) { + result.setColumn(resultIdx, input.column(inputIdx)) + inputIdx += 1 + } else { + require(partitionColumnSchema.fields.filter(_.name.equals(attr.name)).length == 1) + var partitionIdx = 0 + partitionColumnSchema.fields.foreach { f => { + if (f.name.equals(attr.name)) { + ColumnVectorUtils.populate(result.column(resultIdx), partitionValues, partitionIdx) + } + partitionIdx += 1 + }} + } + resultIdx += 1 + } + result + } + private def mergeWithPartitionValues( requiredColumns: Seq[Attribute], dataColumns: Seq[Attribute], @@ -239,7 +276,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { } } - val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[InternalRow]) => { + val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[Object]) => { // Note that we can't use an `UnsafeRowJoiner` to replace the following `JoinedRow` and // `UnsafeProjection`. Because the projection may also adjust column order. val mutableJoinedRow = new JoinedRow() @@ -247,9 +284,25 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val unsafeProjection = UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns) - iterator.map { unsafeDataRow => - unsafeProjection(mutableJoinedRow(unsafeDataRow, unsafePartitionValues)) - } + // If we are returning batches directly, we need to augment them with the partitioning + // columns. We want to do this without a row by row operation. + var columnBatch: ColumnarBatch = null + + iterator.map { input => { + if (input.isInstanceOf[InternalRow]) { + unsafeProjection(mutableJoinedRow( + input.asInstanceOf[InternalRow], unsafePartitionValues)) + } else { + require(input.isInstanceOf[ColumnarBatch]) + val inputBatch = input.asInstanceOf[ColumnarBatch] + if (columnBatch == null) { + columnBatch = projectedColumnBatch(inputBatch, requiredColumns, + dataColumns, partitionColumnSchema, partitionValues) + } + columnBatch.setNumRows(inputBatch.numRows()) + columnBatch + } + }} } // This is an internal RDD whose call site the user should not be concerned with @@ -257,7 +310,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // the call site may add up. 
Utils.withDummyCallSite(dataRows.sparkContext) { new MapPartitionsRDD(dataRows, mapPartitionsFunc, preservesPartitioning = false) - } + }.asInstanceOf[RDD[InternalRow]] } else { dataRows } From 42875aca00af97ec2928963f72bb9e7665480623 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 2 Mar 2016 16:30:52 -0800 Subject: [PATCH 04/10] Import order fixes --- .../sql/execution/datasources/DataSourceStrategy.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 8c1261563c958..f7e79718cde72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -19,26 +19,27 @@ package org.apache.spark.sql.execution.datasources import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.ExecutedCommand -import org.apache.spark.sql.execution.vectorized.{ColumnVectorUtils, ColumnarBatch} +import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVectorUtils} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.collection.BitSet import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.util.collection.BitSet /** * A Strategy for planning scans over data sources defined using the sources API. From cab64e5e834b72118286b169f6b0860d934d9310 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 2 Mar 2016 18:04:19 -0800 Subject: [PATCH 05/10] Fix use after free issue. --- .../apache/spark/sql/execution/vectorized/ColumnarBatch.java | 3 +++ .../spark/sql/execution/vectorized/OnHeapColumnVector.java | 3 --- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 26184d4196e0c..18763672c6e84 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -321,6 +321,9 @@ public int numValidRows() { * projections. 
*/ public void setColumn(int ordinal, ColumnVector column) { + if (column instanceof OffHeapColumnVector) { + throw new NotImplementedException("Need to ref count columns."); + } columns[ordinal] = column; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 305e84a86bdc7..03160d1ec36ce 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -62,9 +62,6 @@ public final long nullsNativeAddress() { @Override public final void close() { - nulls = null; - intData = null; - doubleData = null; } // From f35394c4631eed4413cf6b54455b4839d281b8b7 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Thu, 3 Mar 2016 13:36:16 -0800 Subject: [PATCH 06/10] CR and add partition column benchmark. --- .../vectorized/ColumnVectorUtils.java | 3 +- .../spark/sql/execution/ExistingRDD.scala | 11 ++-- .../datasources/DataSourceStrategy.scala | 17 ++++--- .../parquet/ParquetReadBenchmark.scala | 51 ++++++++++++++++--- 4 files changed, 64 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index e2c88d7b4e802..930c4462a4580 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -64,8 +64,9 @@ public static void populate(ColumnVector col, InternalRow row, int fieldIdx) { col.putDoubles(0, capacity, row.getDouble(fieldIdx)); } else if (t == DataTypes.StringType) { UTF8String v = row.getUTF8String(fieldIdx); + byte[] bytes = v.getBytes(); for (int i = 0; i < capacity; i++) { - col.putByteArray(i, v.getBytes()); + col.putByteArray(i, bytes); } } else if (t instanceof DecimalType) { DecimalType dt = (DecimalType)t; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 455a0e75361e1..c9ab466dda796 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -162,18 +162,21 @@ private[sql] case class PhysicalRDD( ctx.addNewFunction(scanBatches, s""" | private void $scanBatches($columnarBatchClz batch) throws java.io.IOException { + | int batchIdx = 0; | while (true) { | int numRows = batch.numRows(); - | $numOutputRows.add(numRows); - | for (int i = 0; i < numRows; i++) { - | InternalRow $row = batch.getRow(i); + | if (batchIdx == 0) $numOutputRows.add(numRows); + | + | while (batchIdx < numRows) { + | InternalRow $row = batch.getRow(batchIdx++); | ${columns.map(_.code).mkString("\n").trim} | ${consume(ctx, columns).trim} + | if (shouldStop()) return; | } | - | if (shouldStop()) return; | if (!$input.hasNext()) break; | batch = ($columnarBatchClz)$input.next(); + | batchIdx = 0; | } | }""".stripMargin) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index f7e79718cde72..c16a379654d08 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -222,12 +222,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { sparkPlan } - // Creates a ColumnarBatch that contains the values for `requiredColumns`. These columns can - // either come from `input` (columns scanned from the data source) or from the partitioning - // values (data from `partitionValues`). This is done *once* per physical partition. When - // the column is from `input`, it just references the same underlying column. When using - // partition columns, the column is populated once. - // TODO: there's probably a cleaner way to do this. + /** + * Creates a ColumnarBatch that contains the values for `requiredColumns`. These columns can + * either come from `input` (columns scanned from the data source) or from the partitioning + * values (data from `partitionValues`). This is done *once* per physical partition. When + * the column is from `input`, it just references the same underlying column. When using + * partition columns, the column is populated once. + * TODO: there's probably a cleaner way to do this. + */ private def projectedColumnBatch( input: ColumnarBatch, requiredColumns: Seq[Attribute], @@ -288,6 +290,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // If we are returning batches directly, we need to augment them with the partitioning // columns. We want to do this without a row by row operation. var columnBatch: ColumnarBatch = null + var firstBatch: ColumnarBatch = null iterator.map { input => { if (input.isInstanceOf[InternalRow]) { @@ -297,9 +300,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { require(input.isInstanceOf[ColumnarBatch]) val inputBatch = input.asInstanceOf[ColumnarBatch] if (columnBatch == null) { + firstBatch = inputBatch columnBatch = projectedColumnBatch(inputBatch, requiredColumns, dataColumns, partitionColumnSchema, partitionValues) } + require(firstBatch == inputBatch, "Reader must return the same batch object.") columnBatch.setNumRows(inputBatch.numRows()) columnBatch } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 93265bae1e6b5..471da0d5e9953 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -95,7 +95,7 @@ object ParquetReadBenchmark { val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray // Driving the parquet reader directly without Spark. - parquetReaderBenchmark.addCase("ParquetReader") { num => + parquetReaderBenchmark.addCase("ParquetReader Non-Vectorized") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => val reader = new UnsafeRowParquetRecordReader @@ -110,7 +110,7 @@ object ParquetReadBenchmark { } // Driving the parquet reader in batch mode directly. 
- parquetReaderBenchmark.addCase("ParquetReader(Batched)") { num => + parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => val reader = new UnsafeRowParquetRecordReader @@ -133,7 +133,7 @@ object ParquetReadBenchmark { } // Decoding in vectorized but having the reader return rows. - parquetReaderBenchmark.addCase("ParquetReader(Batch -> Row)") { num => + parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => val reader = new UnsafeRowParquetRecordReader @@ -167,7 +167,7 @@ object ParquetReadBenchmark { Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz Parquet Reader Single Int Column Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - ParquetReader 565 / 609 27.8 35.9 1.0X + ParquetReader (Non-vectorized) 565 / 609 27.8 35.9 1.0X ParquetReader(Batched) 165 / 174 95.3 10.5 3.4X ParquetReader(Batch -> Row) 158 / 188 99.3 10.1 3.6X */ @@ -264,9 +264,46 @@ object ParquetReadBenchmark { } } + def partitionTableScanBenchmark(values: Int): Unit = { + withTempPath { dir => + withTempTable("t1", "tempTable") { + sqlContext.range(values).registerTempTable("t1") + sqlContext.sql("select id % 2 as p, cast(id as INT) as id from t1") + .write.partitionBy("p").parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + + val benchmark = new Benchmark("Partitioned Table", values) + + benchmark.addCase("Read data column") { iter => + sqlContext.sql("select sum(id) from tempTable").collect + } + + benchmark.addCase("Read partition column") { iter => + sqlContext.sql("select sum(p) from tempTable").collect + } + + benchmark.addCase("Read both columns") { iter => + sqlContext.sql("select sum(p), sum(id) from tempTable").collect + } + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Read data column 751 / 805 20.9 47.8 1.0X + Read partition column 713 / 761 22.1 45.3 1.1X + Read both columns 1004 / 1109 15.7 63.8 0.7X + */ + benchmark.run() + } + } + } + def main(args: Array[String]): Unit = { - intScanBenchmark(1024 * 1024 * 15) - intStringScanBenchmark(1024 * 1024 * 10) - stringDictionaryScanBenchmark(1024 * 1024 * 10) + //intScanBenchmark(1024 * 1024 * 15) + //intStringScanBenchmark(1024 * 1024 * 10) + //stringDictionaryScanBenchmark(1024 * 1024 * 10) + partitionTableScanBenchmark(1024 * 1024 * 15) + } } From 345031346794614f1a543ad76b706cacc9fbe558 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Thu, 3 Mar 2016 13:47:37 -0800 Subject: [PATCH 07/10] CR --- .../vectorized/ColumnVectorUtils.java | 4 +- .../parquet/ParquetReadBenchmark.scala | 37 +++++++++---------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 930c4462a4580..68f146f7a2622 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -71,7 +71,9 @@ public static void populate(ColumnVector col, InternalRow row, int fieldIdx) { } else if (t instanceof DecimalType) { 
DecimalType dt = (DecimalType)t; Decimal d = row.getDecimal(fieldIdx, dt.precision(), dt.scale()); - if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) { + if (dt.precision() <= Decimal.MAX_INT_DIGITS()) { + col.putInts(0, capacity, (int)d.toUnscaledLong()); + } else if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) { col.putLongs(0, capacity, d.toUnscaledLong()); } else { final BigInteger integer = d.toJavaBigDecimal().unscaledValue(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 471da0d5e9953..38c3618a82ef9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -157,9 +157,9 @@ object ParquetReadBenchmark { Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz SQL Single Int Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 657 / 778 23.9 41.8 1.0X - SQL Parquet MR 1606 / 1731 9.8 102.1 0.4X - SQL Parquet Non-Vectorized 1133 / 1216 13.9 72.1 0.6X + SQL Parquet Vectorized 215 / 262 73.0 13.7 1.0X + SQL Parquet MR 1946 / 2083 8.1 123.7 0.1X + SQL Parquet Non-Vectorized 1079 / 1213 14.6 68.6 0.2X */ sqlBenchmark.run() @@ -167,9 +167,9 @@ object ParquetReadBenchmark { Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz Parquet Reader Single Int Column Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - ParquetReader (Non-vectorized) 565 / 609 27.8 35.9 1.0X - ParquetReader(Batched) 165 / 174 95.3 10.5 3.4X - ParquetReader(Batch -> Row) 158 / 188 99.3 10.1 3.6X + ParquetReader Non-Vectorized 610 / 737 25.8 38.8 1.0X + ParquetReader Vectorized 123 / 152 127.8 7.8 5.0X + ParquetReader Vectorized -> Row 165 / 180 95.2 10.5 3.7X */ parquetReaderBenchmark.run() } @@ -222,10 +222,10 @@ object ParquetReadBenchmark { Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 1025 / 1180 10.2 97.8 1.0X - SQL Parquet MR 2157 / 2222 4.9 205.7 0.5X - SQL Parquet Non-vectorized 1450 / 1466 7.2 138.3 0.7X - ParquetReader Non-vectorized 1005 / 1022 10.4 95.9 1.0X + SQL Parquet Vectorized 628 / 720 16.7 59.9 1.0X + SQL Parquet MR 1905 / 2239 5.5 181.7 0.3X + SQL Parquet Non-vectorized 1429 / 1732 7.3 136.3 0.4X + ParquetReader Non-vectorized 989 / 1357 10.6 94.3 0.6X */ benchmark.run() } @@ -256,8 +256,8 @@ object ParquetReadBenchmark { Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz String Dictionary: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 578 / 593 18.1 55.1 1.0X - SQL Parquet MR 1021 / 1032 10.3 97.4 0.6X + SQL Parquet Vectorized 329 / 337 31.9 31.4 1.0X + SQL Parquet MR 1131 / 1325 9.3 107.8 0.3X */ benchmark.run() } @@ -290,9 +290,9 @@ object ParquetReadBenchmark { Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Read data column 
751 / 805 20.9 47.8 1.0X - Read partition column 713 / 761 22.1 45.3 1.1X - Read both columns 1004 / 1109 15.7 63.8 0.7X + Read data column 191 / 250 82.1 12.2 1.0X + Read partition column 82 / 86 192.4 5.2 2.3X + Read both columns 220 / 248 71.5 14.0 0.9X */ benchmark.run() } @@ -300,10 +300,9 @@ object ParquetReadBenchmark { } def main(args: Array[String]): Unit = { - //intScanBenchmark(1024 * 1024 * 15) - //intStringScanBenchmark(1024 * 1024 * 10) - //stringDictionaryScanBenchmark(1024 * 1024 * 10) + intScanBenchmark(1024 * 1024 * 15) + intStringScanBenchmark(1024 * 1024 * 10) + stringDictionaryScanBenchmark(1024 * 1024 * 10) partitionTableScanBenchmark(1024 * 1024 * 15) - } } From f5f1e2be578ad40daafe25c6cc1b09bb4f8bb71a Mon Sep 17 00:00:00 2001 From: Nong Li Date: Thu, 3 Mar 2016 20:40:14 -0800 Subject: [PATCH 08/10] Fix batching. --- .../spark/sql/execution/ExistingRDD.scala | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index c9ab466dda796..50672beb0b90e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -139,9 +139,14 @@ private[sql] case class PhysicalRDD( // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen // never requires UnsafeRow as input. override protected def doProduce(ctx: CodegenContext): String = { + val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" val input = ctx.freshName("input") + val idx = ctx.freshName("batchIdx") + val batch = ctx.freshName("batch") // PhysicalRDD always just has one input ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") + ctx.addMutableState("int", idx, s"$idx = 0;") val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") @@ -156,27 +161,28 @@ private[sql] case class PhysicalRDD( // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know // here which path to use. Fix this. 
- val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" val scanBatches = ctx.freshName("processBatches") ctx.addNewFunction(scanBatches, s""" - | private void $scanBatches($columnarBatchClz batch) throws java.io.IOException { - | int batchIdx = 0; + | private void $scanBatches() throws java.io.IOException { | while (true) { - | int numRows = batch.numRows(); - | if (batchIdx == 0) $numOutputRows.add(numRows); + | int numRows = $batch.numRows(); + | if ($idx == 0) $numOutputRows.add(numRows); | - | while (batchIdx < numRows) { - | InternalRow $row = batch.getRow(batchIdx++); + | while ($idx < numRows) { + | InternalRow $row = $batch.getRow($idx++); | ${columns.map(_.code).mkString("\n").trim} | ${consume(ctx, columns).trim} | if (shouldStop()) return; | } | - | if (!$input.hasNext()) break; - | batch = ($columnarBatchClz)$input.next(); - | batchIdx = 0; + | if (!$input.hasNext()) { + | $batch = null; + | break; + | } + | $batch = ($columnarBatchClz)$input.next(); + | $idx = 0; | } | }""".stripMargin) @@ -195,12 +201,17 @@ private[sql] case class PhysicalRDD( | }""".stripMargin) s""" - | if ($input.hasNext()) { - | Object firstValue = $input.next(); - | if (firstValue instanceof $columnarBatchClz) { - | $scanBatches(($columnarBatchClz)firstValue); + | if ($batch != null || $input.hasNext()) { + | if ($batch == null) { + | Object value = $input.next(); + | if (value instanceof $columnarBatchClz) { + | $batch = ($columnarBatchClz)value; + | $scanBatches(); + | } else { + | $scanRows((InternalRow)value); + | } | } else { - | $scanRows((InternalRow)firstValue); + | $scanBatches(); | } | } """.stripMargin From ed79eee5daeab177c4350f6f111898f0e7339309 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Thu, 3 Mar 2016 23:01:23 -0800 Subject: [PATCH 09/10] Fix test for bucketed tables. --- .../sql/execution/datasources/DataSourceStrategy.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c16a379654d08..69a6d23203b93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -290,7 +290,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // If we are returning batches directly, we need to augment them with the partitioning // columns. We want to do this without a row by row operation. 
var columnBatch: ColumnarBatch = null - var firstBatch: ColumnarBatch = null + var mergedBatch: ColumnarBatch = null iterator.map { input => { if (input.isInstanceOf[InternalRow]) { @@ -299,12 +299,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { } else { require(input.isInstanceOf[ColumnarBatch]) val inputBatch = input.asInstanceOf[ColumnarBatch] - if (columnBatch == null) { - firstBatch = inputBatch + if (inputBatch != mergedBatch) { + mergedBatch = inputBatch columnBatch = projectedColumnBatch(inputBatch, requiredColumns, dataColumns, partitionColumnSchema, partitionValues) } - require(firstBatch == inputBatch, "Reader must return the same batch object.") columnBatch.setNumRows(inputBatch.numRows()) columnBatch } From 48102e3ba2229d826d35da77b5b6bea4e2107f2b Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 4 Mar 2016 11:18:19 -0800 Subject: [PATCH 10/10] CR --- .../spark/sql/execution/ExistingRDD.scala | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 50672beb0b90e..36e656b8b6abf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -201,17 +201,15 @@ private[sql] case class PhysicalRDD( | }""".stripMargin) s""" - | if ($batch != null || $input.hasNext()) { - | if ($batch == null) { - | Object value = $input.next(); - | if (value instanceof $columnarBatchClz) { - | $batch = ($columnarBatchClz)value; - | $scanBatches(); - | } else { - | $scanRows((InternalRow)value); - | } - | } else { + | if ($batch != null) { + | $scanBatches(); + | } else if ($input.hasNext()) { + | Object value = $input.next(); + | if (value instanceof $columnarBatchClz) { + | $batch = ($columnarBatchClz)value; | $scanBatches(); + | } else { + | $scanRows((InternalRow)value); | } | } """.stripMargin
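
For reference, a minimal usage sketch of the batch-returning path these commits add (not part of the patches above), written against the reader APIs introduced in PATCH 01 and the benchmark code in PATCH 06; the file path and projected column are assumptions for illustration:

    // Hypothetical example only -- mirrors the UnsafeRowParquetRecordReader usage in
    // ParquetReadBenchmark, with enableReturningBatches() switched on so that
    // getCurrentValue returns a ColumnarBatch instead of one InternalRow at a time.
    import scala.collection.JavaConverters._
    import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader
    import org.apache.spark.sql.execution.vectorized.ColumnarBatch

    val parquetFilePath = "/tmp/example.parquet"   // assumed path, for illustration
    val reader = new UnsafeRowParquetRecordReader
    reader.initialize(parquetFilePath, ("id" :: Nil).asJava)
    reader.resultBatch()               // allocate the ColumnarBatch (same order as in SqlNewHadoopRDD)
    reader.enableReturningBatches()    // getCurrentValue now yields whole batches
    var sum = 0L
    while (reader.nextKeyValue()) {
      val batch = reader.getCurrentValue.asInstanceOf[ColumnarBatch]
      var i = 0
      while (i < batch.numRows()) {
        val row = batch.getRow(i)      // row view is reused across calls
        if (!row.isNullAt(0)) sum += row.getInt(0)
        i += 1
      }
    }
    reader.close()
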