Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.gluten.columnarbatch;

import org.apache.gluten.backendsapi.BackendsApiManager;
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators;
import org.apache.gluten.runtime.Runtime;
import org.apache.gluten.runtime.Runtimes;

Expand Down Expand Up @@ -56,6 +57,7 @@ public static void checkNonVeloxBatch(ColumnarBatch batch) {
}

public static ColumnarBatch toVeloxBatch(ColumnarBatch input) {
ColumnarBatches.checkOffloaded(input);
if (ColumnarBatches.isZeroColumnBatch(input)) {
return input;
}
Expand Down Expand Up @@ -86,6 +88,26 @@ public static ColumnarBatch toVeloxBatch(ColumnarBatch input) {
return input;
}

/**
 * Returns a Velox-format version of the given batch: converts when required, otherwise hands the
 * batch back directly.
 *
 * <p>Should only be used for certain conditions when unable to insert explicit to-Velox
 * transitions through query planner.
 *
 * <p>For example, used by {@link org.apache.spark.sql.execution.ColumnarCachedBatchSerializer} as
 * Spark directly calls API ColumnarCachedBatchSerializer#convertColumnarBatchToCachedBatch for
 * query plan that returns supportsColumnar=true without generating a cache-write query plan node.
 */
public static ColumnarBatch ensureVeloxBatch(ColumnarBatch input) {
  // Offload first so a heavy (JVM/Arrow-readable) batch becomes a light, native-backed one.
  final ColumnarBatch offloaded =
      ColumnarBatches.ensureOffloaded(ArrowBufferAllocators.contextInstance(), input);
  // Already Velox-formatted batches pass through untouched; anything else gets converted.
  return isVeloxBatch(offloaded) ? offloaded : toVeloxBatch(offloaded);
}

/**
* Combine multiple columnar batches horizontally, assuming each of them is already offloaded.
* Otherwise {@link UnsupportedOperationException} will be thrown.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,24 +171,16 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with Logging {
conf: SQLConf): RDD[CachedBatch] = {
input.mapPartitions {
it =>
val lightBatches = it.map {
val veloxBatches = it.map {
/* Native code needs a Velox offloaded batch, making sure to offload
if heavy batch is encountered */
batch =>
val heavy = ColumnarBatches.isHeavyBatch(batch)
if (heavy) {
val offloaded = VeloxColumnarBatches.toVeloxBatch(
ColumnarBatches.offload(ArrowBufferAllocators.contextInstance(), batch))
offloaded
} else {
batch
}
batch => VeloxColumnarBatches.ensureVeloxBatch(batch)
}
new Iterator[CachedBatch] {
override def hasNext: Boolean = lightBatches.hasNext
override def hasNext: Boolean = veloxBatches.hasNext

override def next(): CachedBatch = {
val batch = lightBatches.next()
val batch = veloxBatches.next()
val results =
ColumnarBatchSerializerJniWrapper
.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,33 +67,26 @@ class ArrowCsvScanSuiteV2 extends ArrowCsvScanSuite {
}
}

/** Since https://github.com/apache/incubator-gluten/pull/5850. */
abstract class ArrowCsvScanSuite extends VeloxWholeStageTransformerSuite {
override protected val resourcePath: String = "N/A"
override protected val fileFormat: String = "N/A"

protected val rootPath: String = getClass.getResource("/").getPath

override def beforeAll(): Unit = {
super.beforeAll()
createCsvTables()
}

override def afterAll(): Unit = {
super.afterAll()
}

/**
 * Arrow CSV scan suite run with Gluten's columnar table cache enabled, exercising the
 * cached-batch serialization path (see the GLUTEN-8453 test below).
 */
class ArrowCsvScanWithTableCacheSuite extends ArrowCsvScanSuiteBase {
override protected def sparkConf: SparkConf = {
// NOTE(review): most settings below mirror ArrowCsvScanSuiteBase#sparkConf; the
// suite-specific additions appear to be the V1 CSV source list and the table cache flag.
super.sparkConf
.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
.set("spark.sql.files.maxPartitionBytes", "1g")
.set("spark.sql.shuffle.partitions", "1")
.set("spark.memory.offHeap.size", "2g")
.set("spark.unsafe.exceptionOnMemoryLeak", "true")
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set(GlutenConfig.NATIVE_ARROW_READER_ENABLED.key, "true")
// Force DataSource V1 for CSV so the v1 Arrow scan path is taken.
.set("spark.sql.sources.useV1SourceList", "csv")
// Enable Gluten's columnar table cache — the feature under test.
.set(GlutenConfig.COLUMNAR_TABLE_CACHE_ENABLED.key, "true")
}

/**
 * Test for GLUTEN-8453: https://github.com/apache/incubator-gluten/issues/8453. To make sure no
 * error is thrown when caching an Arrow Java query plan.
 */
test("csv scan v1 with table cache") {
// Caching then collecting triggers the cache-write path; the assertion also checks the
// expected row count of the "student" table (3 rows).
val df = spark.sql("select * from student")
df.cache()
assert(df.collect().length == 3)
}
}

/** Since https://github.com/apache/incubator-gluten/pull/5850. */
abstract class ArrowCsvScanSuite extends ArrowCsvScanSuiteBase {
test("csv scan with option string as null") {
val df = runAndCompare("select * from student_option_str")()
val plan = df.queryExecution.executedPlan
Expand Down Expand Up @@ -152,6 +145,33 @@ abstract class ArrowCsvScanSuite extends VeloxWholeStageTransformerSuite {
val df = runAndCompare("select count(1) from student")()
checkLengthAndPlan(df, 1)
}
}

abstract class ArrowCsvScanSuiteBase extends VeloxWholeStageTransformerSuite {
override protected val resourcePath: String = "N/A"
override protected val fileFormat: String = "N/A"

protected val rootPath: String = getClass.getResource("/").getPath

// Set up the shared Spark session (via the base suite), then register the CSV-backed
// test tables used by all subclasses. createCsvTables is defined later in this file.
override def beforeAll(): Unit = {
super.beforeAll()
createCsvTables()
}

// No suite-specific teardown; delegates straight to the base implementation.
// NOTE(review): this override is redundant and could be removed.
override def afterAll(): Unit = {
super.afterAll()
}

// Common Spark configuration shared by all Arrow CSV scan suites.
override protected def sparkConf: SparkConf = {
super.sparkConf
// Use Gluten's columnar shuffle manager.
.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
// Large file-split size plus a single shuffle partition — presumably to keep plans
// single-partitioned and deterministic for assertions; confirm against the suites.
.set("spark.sql.files.maxPartitionBytes", "1g")
.set("spark.sql.shuffle.partitions", "1")
// Native/Velox execution allocates off-heap memory; fail tests on any detected leak.
.set("spark.memory.offHeap.size", "2g")
.set("spark.unsafe.exceptionOnMemoryLeak", "true")
// Disable broadcast joins so shuffle-based join plans are exercised.
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
// Enable the native Arrow CSV reader — the code path under test.
.set(GlutenConfig.NATIVE_ARROW_READER_ENABLED.key, "true")
}

private def createCsvTables(): Unit = {
spark.read
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,17 @@ private static BatchType identifyBatchType(ColumnarBatch batch) {
}

/** Heavy batch: Data is readable from JVM and formatted as Arrow data. */
public static boolean isHeavyBatch(ColumnarBatch batch) {
@VisibleForTesting
static boolean isHeavyBatch(ColumnarBatch batch) {
return identifyBatchType(batch) == BatchType.HEAVY;
}

/**
* Light batch: Data is not readable from JVM, a long int handle (which is a pointer usually) is
* used to bind the batch to a native side implementation.
*/
public static boolean isLightBatch(ColumnarBatch batch) {
@VisibleForTesting
static boolean isLightBatch(ColumnarBatch batch) {
return identifyBatchType(batch) == BatchType.LIGHT;
}

Expand Down Expand Up @@ -128,7 +130,7 @@ public static ColumnarBatch select(String backendName, ColumnarBatch batch, int[
* Ensure the input batch is offloaded as native-based columnar batch (See {@link IndicatorVector}
* and {@link PlaceholderVector}). This method will close the input column batch after offloaded.
*/
private static ColumnarBatch ensureOffloaded(BufferAllocator allocator, ColumnarBatch batch) {
static ColumnarBatch ensureOffloaded(BufferAllocator allocator, ColumnarBatch batch) {
if (ColumnarBatches.isLightBatch(batch)) {
return batch;
}
Expand All @@ -140,7 +142,7 @@ private static ColumnarBatch ensureOffloaded(BufferAllocator allocator, Columnar
* take place if loading is required, which means when the input batch is not loaded yet. This
* method will close the input column batch after loaded.
*/
private static ColumnarBatch ensureLoaded(BufferAllocator allocator, ColumnarBatch batch) {
static ColumnarBatch ensureLoaded(BufferAllocator allocator, ColumnarBatch batch) {
if (isHeavyBatch(batch)) {
return batch;
}
Expand Down