diff --git a/backends-velox/src/main/java/org/apache/gluten/columnarbatch/VeloxColumnarBatches.java b/backends-velox/src/main/java/org/apache/gluten/columnarbatch/VeloxColumnarBatches.java index db2d08e31435..33f02be08980 100644 --- a/backends-velox/src/main/java/org/apache/gluten/columnarbatch/VeloxColumnarBatches.java +++ b/backends-velox/src/main/java/org/apache/gluten/columnarbatch/VeloxColumnarBatches.java @@ -17,6 +17,7 @@ package org.apache.gluten.columnarbatch; import org.apache.gluten.backendsapi.BackendsApiManager; +import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators; import org.apache.gluten.runtime.Runtime; import org.apache.gluten.runtime.Runtimes; @@ -56,6 +57,7 @@ public static void checkNonVeloxBatch(ColumnarBatch batch) { } public static ColumnarBatch toVeloxBatch(ColumnarBatch input) { + ColumnarBatches.checkOffloaded(input); if (ColumnarBatches.isZeroColumnBatch(input)) { return input; } @@ -86,6 +88,26 @@ public static ColumnarBatch toVeloxBatch(ColumnarBatch input) { return input; } + /** + * Check if a columnar batch is in Velox format. If not, convert it to Velox format then return. + * If already in Velox format, return the batch directly. + * + *
Should only be used for certain conditions when unable to insert explicit to-Velox + * transitions through the query planner. + * + *
For example, used by {@link org.apache.spark.sql.execution.ColumnarCachedBatchSerializer} as + * Spark directly calls API ColumnarCachedBatchSerializer#convertColumnarBatchToCachedBatch for + * query plan that returns supportsColumnar=true without generating a cache-write query plan node. + */ + public static ColumnarBatch ensureVeloxBatch(ColumnarBatch input) { + final ColumnarBatch light = + ColumnarBatches.ensureOffloaded(ArrowBufferAllocators.contextInstance(), input); + if (isVeloxBatch(light)) { + return light; + } + return toVeloxBatch(light); + } + /** * Combine multiple columnar batches horizontally, assuming each of them is already offloaded. * Otherwise {@link UnsupportedOperationException} will be thrown. diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala index 1f9419976f29..16004737ea7f 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala @@ -171,24 +171,16 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with Logging { conf: SQLConf): RDD[CachedBatch] = { input.mapPartitions { it => - val lightBatches = it.map { + val veloxBatches = it.map { /* Native code needs a Velox offloaded batch, making sure to offload if heavy batch is encountered */ - batch => - val heavy = ColumnarBatches.isHeavyBatch(batch) - if (heavy) { - val offloaded = VeloxColumnarBatches.toVeloxBatch( - ColumnarBatches.offload(ArrowBufferAllocators.contextInstance(), batch)) - offloaded - } else { - batch - } + batch => VeloxColumnarBatches.ensureVeloxBatch(batch) } new Iterator[CachedBatch] { - override def hasNext: Boolean = lightBatches.hasNext + override def hasNext: Boolean = veloxBatches.hasNext override def next(): CachedBatch = { - val batch = 
lightBatches.next() + val batch = veloxBatches.next() val results = ColumnarBatchSerializerJniWrapper .create( diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ArrowCsvScanSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ArrowCsvScanSuite.scala index 374fa543af10..c59936a927c1 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ArrowCsvScanSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ArrowCsvScanSuite.scala @@ -67,33 +67,26 @@ class ArrowCsvScanSuiteV2 extends ArrowCsvScanSuite { } } -/** Since https://github.com/apache/incubator-gluten/pull/5850. */ -abstract class ArrowCsvScanSuite extends VeloxWholeStageTransformerSuite { - override protected val resourcePath: String = "N/A" - override protected val fileFormat: String = "N/A" - - protected val rootPath: String = getClass.getResource("/").getPath - - override def beforeAll(): Unit = { - super.beforeAll() - createCsvTables() - } - - override def afterAll(): Unit = { - super.afterAll() - } - +class ArrowCsvScanWithTableCacheSuite extends ArrowCsvScanSuiteBase { override protected def sparkConf: SparkConf = { super.sparkConf - .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") - .set("spark.sql.files.maxPartitionBytes", "1g") - .set("spark.sql.shuffle.partitions", "1") - .set("spark.memory.offHeap.size", "2g") - .set("spark.unsafe.exceptionOnMemoryLeak", "true") - .set("spark.sql.autoBroadcastJoinThreshold", "-1") - .set(GlutenConfig.NATIVE_ARROW_READER_ENABLED.key, "true") + .set("spark.sql.sources.useV1SourceList", "csv") + .set(GlutenConfig.COLUMNAR_TABLE_CACHE_ENABLED.key, "true") + } + + /** + * Test for GLUTEN-8453: https://github.com/apache/incubator-gluten/issues/8453. To make sure no + * error is thrown when caching an Arrow Java query plan. 
+ */ + test("csv scan v1 with table cache") { + val df = spark.sql("select * from student") + df.cache() + assert(df.collect().length == 3) } +} +/** Since https://github.com/apache/incubator-gluten/pull/5850. */ +abstract class ArrowCsvScanSuite extends ArrowCsvScanSuiteBase { test("csv scan with option string as null") { val df = runAndCompare("select * from student_option_str")() val plan = df.queryExecution.executedPlan @@ -152,6 +145,33 @@ abstract class ArrowCsvScanSuite extends VeloxWholeStageTransformerSuite { val df = runAndCompare("select count(1) from student")() checkLengthAndPlan(df, 1) } +} + +abstract class ArrowCsvScanSuiteBase extends VeloxWholeStageTransformerSuite { + override protected val resourcePath: String = "N/A" + override protected val fileFormat: String = "N/A" + + protected val rootPath: String = getClass.getResource("/").getPath + + override def beforeAll(): Unit = { + super.beforeAll() + createCsvTables() + } + + override def afterAll(): Unit = { + super.afterAll() + } + + override protected def sparkConf: SparkConf = { + super.sparkConf + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") + .set("spark.sql.files.maxPartitionBytes", "1g") + .set("spark.sql.shuffle.partitions", "1") + .set("spark.memory.offHeap.size", "2g") + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set("spark.sql.autoBroadcastJoinThreshold", "-1") + .set(GlutenConfig.NATIVE_ARROW_READER_ENABLED.key, "true") + } private def createCsvTables(): Unit = { spark.read diff --git a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java index 5114853363bd..156de4e0d84d 100644 --- a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java +++ b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java @@ -85,7 +85,8 @@ private static BatchType identifyBatchType(ColumnarBatch batch) 
{ } /** Heavy batch: Data is readable from JVM and formatted as Arrow data. */ - public static boolean isHeavyBatch(ColumnarBatch batch) { + @VisibleForTesting + static boolean isHeavyBatch(ColumnarBatch batch) { return identifyBatchType(batch) == BatchType.HEAVY; } @@ -93,7 +94,8 @@ public static boolean isHeavyBatch(ColumnarBatch batch) { * Light batch: Data is not readable from JVM, a long int handle (which is a pointer usually) is * used to bind the batch to a native side implementation. */ - public static boolean isLightBatch(ColumnarBatch batch) { + @VisibleForTesting + static boolean isLightBatch(ColumnarBatch batch) { return identifyBatchType(batch) == BatchType.LIGHT; } @@ -128,7 +130,7 @@ public static ColumnarBatch select(String backendName, ColumnarBatch batch, int[ * Ensure the input batch is offloaded as native-based columnar batch (See {@link IndicatorVector} * and {@link PlaceholderVector}). This method will close the input column batch after offloaded. */ - private static ColumnarBatch ensureOffloaded(BufferAllocator allocator, ColumnarBatch batch) { + static ColumnarBatch ensureOffloaded(BufferAllocator allocator, ColumnarBatch batch) { if (ColumnarBatches.isLightBatch(batch)) { return batch; } @@ -140,7 +142,7 @@ private static ColumnarBatch ensureOffloaded(BufferAllocator allocator, Columnar * take place if loading is required, which means when the input batch is not loaded yet. This * method will close the input column batch after loaded. */ - private static ColumnarBatch ensureLoaded(BufferAllocator allocator, ColumnarBatch batch) { + static ColumnarBatch ensureLoaded(BufferAllocator allocator, ColumnarBatch batch) { if (isHeavyBatch(batch)) { return batch; }