From 3a082dce9192f2056afbec6f998627f0ce7042b9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 3 Sep 2024 12:50:10 -0700 Subject: [PATCH 1/8] chore: Revise array import to more follow C Data Interface semantics --- .../org/apache/comet/vector/NativeUtil.scala | 30 ++++++-- native/core/src/execution/jni_api.rs | 68 +++++++++++-------- native/core/src/execution/utils.rs | 16 +++++ .../org/apache/comet/CometExecIterator.scala | 13 ++-- .../main/scala/org/apache/comet/Native.scala | 10 +-- .../spark/sql/comet/CometExecUtils.scala | 2 +- .../CometTakeOrderedAndProjectExec.scala | 4 +- .../shuffle/CometShuffleExchangeExec.scala | 1 + .../apache/spark/sql/comet/operators.scala | 14 ++-- 9 files changed, 107 insertions(+), 51 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 751aa6bb0a..554a53e3b0 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -56,6 +56,28 @@ class NativeUtil { */ private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider + /** + * Allocates Arrow structs for the given number of columns. + * + * @param numCols + * the number of columns + * @return + * a pair of arrays containing memory addresses of Arrow arrays and Arrow schemas + */ + def allocateArrowStructs(numCols: Int): (Array[Long], Array[Long]) = { + val arrayAddrs = new Array[Long](numCols) + val schemaAddrs = new Array[Long](numCols) + + (0 until numCols).foreach { index => + val arrowSchema = ArrowSchema.allocateNew(allocator) + val arrowArray = ArrowArray.allocateNew(allocator) + arrayAddrs(index) = arrowArray.memoryAddress() + schemaAddrs(index) = arrowSchema.memoryAddress() + } + + (arrayAddrs, schemaAddrs) + } + /** * Exports a Comet `ColumnarBatch` into a list of memory addresses that can be consumed by the * native execution. 
@@ -110,12 +132,12 @@ class NativeUtil { * @return * a list of Comet vectors */ - def importVector(arrayAddress: Array[Long]): Seq[CometVector] = { + def importVector(arrayAddrs: Array[Long], schemaAddrs: Array[Long]): Seq[CometVector] = { val arrayVectors = mutable.ArrayBuffer.empty[CometVector] - for (i <- arrayAddress.indices by 2) { - val arrowSchema = ArrowSchema.wrap(arrayAddress(i + 1)) - val arrowArray = ArrowArray.wrap(arrayAddress(i)) + (0 until arrayAddrs.length).foreach { i => + val arrowSchema = ArrowSchema.wrap(schemaAddrs(i)) + val arrowArray = ArrowArray.wrap(arrayAddrs(i)) // Native execution should always have 'useDecimal128' set to true since it doesn't support // other cases. diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 2d99a854d3..4b2225b40c 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -17,10 +17,7 @@ //! Define JNI APIs which can be called from Java/Scala. -use arrow::{ - datatypes::DataType as ArrowDataType, - ffi::{FFI_ArrowArray, FFI_ArrowSchema}, -}; +use arrow::datatypes::DataType as ArrowDataType; use arrow_array::RecordBatch; use datafusion::{ execution::{ @@ -78,8 +75,6 @@ struct ExecutionContext { pub input_sources: Vec>, /// The record batch stream to pull results from pub stream: Option, - /// The FFI arrays. We need to keep them alive here. - pub ffi_arrays: Vec<(Arc, Arc)>, /// Configurations for DF execution pub conf: HashMap, /// The Tokio runtime used for async. @@ -177,7 +172,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( scans: vec![], input_sources, stream: None, - ffi_arrays: vec![], conf: configs, runtime, metrics, @@ -265,14 +259,33 @@ fn parse_bool(conf: &HashMap, name: &str) -> CometResult { } /// Prepares arrow arrays for output. 
-fn prepare_output( +unsafe fn prepare_output( env: &mut JNIEnv, + array_addrs: jlongArray, + schema_addrs: jlongArray, output_batch: RecordBatch, exec_context: &mut ExecutionContext, -) -> CometResult { +) -> CometResult { + let array_address_array = JLongArray::from_raw(array_addrs); + let num_cols = env.get_array_length(&array_address_array)? as usize; + + let array_addrs = env.get_array_elements(&array_address_array, ReleaseMode::NoCopyBack)?; + let array_addrs = &*array_addrs; + + let schema_address_array = JLongArray::from_raw(schema_addrs); + let schema_addrs = env.get_array_elements(&schema_address_array, ReleaseMode::NoCopyBack)?; + let schema_addrs = &*schema_addrs; + let results = output_batch.columns(); let num_rows = output_batch.num_rows(); + if results.len() != num_cols { + return Err(CometError::Internal(format!( + "Output column count mismatch: expected {num_cols}, got {}", + results.len() + ))); + } + if exec_context.debug_native { // Validate the output arrays. for array in results.iter() { @@ -286,32 +299,22 @@ fn prepare_output( let return_flag = 1; let long_array = env.new_long_array((results.len() * 2) as i32 + 2)?; - env.set_long_array_region(&long_array, 0, &[return_flag, num_rows as jlong])?; - - let mut arrays = vec![]; + env.set_long_array_region(long_array, 0, &[return_flag, num_rows as jlong])?; let mut i = 0; while i < results.len() { let array_ref = results.get(i).ok_or(CometError::IndexOutOfBounds(i))?; - let (array, schema) = array_ref.to_data().to_spark()?; + array_ref + .to_data() + .move_to_spark(array_addrs[i], schema_addrs[i])?; - unsafe { - let arrow_array = Arc::from_raw(array as *const FFI_ArrowArray); - let arrow_schema = Arc::from_raw(schema as *const FFI_ArrowSchema); - arrays.push((arrow_array, arrow_schema)); - } - - env.set_long_array_region(&long_array, (i * 2) as i32 + 2, &[array, schema])?; i += 1; } // Update metrics update_metrics(env, exec_context)?; - // Record the pointer to allocated Arrow Arrays - 
exec_context.ffi_arrays = arrays; - - Ok(long_array.into_raw()) + Ok(num_rows as jlong) } /// Pull the next input from JVM. Note that we cannot pull input batches in @@ -337,7 +340,9 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( e: JNIEnv, _class: JClass, exec_context: jlong, -) -> jlongArray { + array_addrs: jlongArray, + schema_addrs: jlongArray, +) -> jlong { try_unwrap_or_throw(&e, |mut env| { // Retrieve the query let exec_context = get_execution_context(exec_context); @@ -383,7 +388,13 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( match poll_output { Poll::Ready(Some(output)) => { - return prepare_output(&mut env, output?, exec_context); + return prepare_output( + &mut env, + array_addrs, + schema_addrs, + output?, + exec_context, + ); } Poll::Ready(None) => { // Reaches EOF of output. @@ -399,10 +410,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( } } - let long_array = env.new_long_array(1)?; - env.set_long_array_region(&long_array, 0, &[-1])?; - - return Ok(long_array.into_raw()); + return Ok(-1); } // A poll pending means there are more than one blocking operators, // we don't need go back-forth between JVM/Native. Just keeping polling. diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index cb21391a23..7546726214 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -55,6 +55,9 @@ pub trait SparkArrowConvert { /// Convert Arrow Arrays to C data interface. /// It returns a tuple (ArrowArray address, ArrowSchema address). fn to_spark(&self) -> Result<(i64, i64), ExecutionError>; + + /// Move Arrow Arrays to C data interface. 
+ fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError>; } impl SparkArrowConvert for ArrayData { @@ -96,6 +99,19 @@ impl SparkArrowConvert for ArrayData { Ok((array as i64, schema as i64)) } + + /// Move this ArrowData to pointers of Arrow C data interface. + fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> { + unsafe { std::ptr::replace(array as *mut FFI_ArrowArray, FFI_ArrowArray::new(self)) }; + unsafe { + std::ptr::replace( + schema as *mut FFI_ArrowSchema, + FFI_ArrowSchema::try_from(self.data_type())?, + ) + }; + + Ok(()) + } } /// Converts a slice of bytes to i128. The bytes are serialized in big-endian order by diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index dcdc8ae92f..eea92a1be6 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -43,6 +43,7 @@ import org.apache.comet.vector.NativeUtil class CometExecIterator( val id: Long, inputs: Seq[Iterator[ColumnarBatch]], + numOutputCols: Int, protobufQueryPlan: Array[Byte], nativeMetrics: CometMetricNode) extends Iterator[ColumnarBatch] { @@ -100,18 +101,18 @@ class CometExecIterator( } def getNextBatch(): Option[ColumnarBatch] = { + val (arrayAddrs, schemaAddrs) = nativeUtil.allocateArrowStructs(numOutputCols) + // we execute the native plan each time we need another output batch and this could // result in multiple input batches being processed - val result = nativeLib.executePlan(plan) + val result = nativeLib.executePlan(plan, arrayAddrs, schemaAddrs) - result(0) match { + result match { case -1 => // EOF None - case 1 => - val numRows = result(1) - val addresses = result.slice(2, result.length) - val cometVectors = nativeUtil.importVector(addresses) + case numRows => + val cometVectors = nativeUtil.importVector(arrayAddrs, schemaAddrs) Some(new 
ColumnarBatch(cometVectors.toArray, numRows.toInt)) case flag => throw new IllegalStateException(s"Invalid native flag: $flag") diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 97ded91b26..03a9dea0c6 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -58,12 +58,14 @@ class Native extends NativeBase { * * @param plan * the address to native query plan. + * @param arrayAddrs + * the addresses of Arrow Array structures + * @param schemaAddrs + * the addresses of Arrow Schema structures * @return - * an array containing: 1) the status flag (1 for normal returned arrays, -1 for end of - * output) 2) (optional) the number of rows if returned flag is 1 3) the addresses of output - * Arrow arrays + * the number of rows, or -1 if the end of the output has been reached. */ - @native def executePlan(plan: Long): Array[Long] + @native def executePlan(plan: Long, arrayAddrs: Array[Long], schemaAddrs: Array[Long]): Long /** * Release and drop the native query plan object and context object. 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala index 5931920a20..8cc03856c2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala @@ -53,7 +53,7 @@ object CometExecUtils { limit: Int): RDD[ColumnarBatch] = { childPlan.mapPartitionsInternal { iter => val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get - CometExec.getCometIterator(Seq(iter), limitOp) + CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index ce40a2ebf9..6220c809da 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -82,7 +82,7 @@ case class CometTakeOrderedAndProjectExec( CometExecUtils .getTopKNativePlan(child.output, sortOrder, child, limit) .get - CometExec.getCometIterator(Seq(iter), topK) + CometExec.getCometIterator(Seq(iter), child.output.length, topK) } } @@ -102,7 +102,7 @@ case class CometTakeOrderedAndProjectExec( val topKAndProjection = CometExecUtils .getProjectionNativePlan(projectList, child.output, sortOrder, child, limit) .get - val it = CometExec.getCometIterator(Seq(iter), topKAndProjection) + val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection) setSubqueries(it.id, this) Option(TaskContext.get()).foreach { context => diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 028ba24393..6430a7899f 100644 --- 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -487,6 +487,7 @@ class CometShuffleWriteProcessor( val cometIter = CometExec.getCometIterator( Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]), + outputAttributes.length, nativePlan, nativeMetrics) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 35ae8ad629..0b98b62c06 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -119,19 +119,21 @@ object CometExec { def getCometIterator( inputs: Seq[Iterator[ColumnarBatch]], + numOutputCols: Int, nativePlan: Operator): CometExecIterator = { - getCometIterator(inputs, nativePlan, CometMetricNode(Map.empty)) + getCometIterator(inputs, numOutputCols, nativePlan, CometMetricNode(Map.empty)) } def getCometIterator( inputs: Seq[Iterator[ColumnarBatch]], + numOutputCols: Int, nativePlan: Operator, nativeMetrics: CometMetricNode): CometExecIterator = { val outputStream = new ByteArrayOutputStream() nativePlan.writeTo(outputStream) outputStream.close() val bytes = outputStream.toByteArray - new CometExecIterator(newIterId, inputs, bytes, nativeMetrics) + new CometExecIterator(newIterId, inputs, numOutputCols, bytes, nativeMetrics) } /** @@ -213,8 +215,12 @@ abstract class CometNativeExec extends CometExec { val nativeMetrics = CometMetricNode.fromCometPlan(this) def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = { - val it = - new CometExecIterator(CometExec.newIterId, inputs, serializedPlanCopy, nativeMetrics) + val it = new CometExecIterator( + CometExec.newIterId, + inputs, + output.length, + serializedPlanCopy, + nativeMetrics) setSubqueries(it.id, this) From 
26064614717f7b3d8cc59ad50037f47780e14954 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 3 Sep 2024 13:37:29 -0700 Subject: [PATCH 2/8] more --- native/core/src/execution/jni_api.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 4b2225b40c..c1cc784556 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -296,11 +296,6 @@ unsafe fn prepare_output( } } - let return_flag = 1; - - let long_array = env.new_long_array((results.len() * 2) as i32 + 2)?; - env.set_long_array_region(long_array, 0, &[return_flag, num_rows as jlong])?; - let mut i = 0; while i < results.len() { let array_ref = results.get(i).ok_or(CometError::IndexOutOfBounds(i))?; From 8a23a6ed9a5139acb841010f5fad4a7d37eff4d0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 3 Sep 2024 23:00:22 -0700 Subject: [PATCH 3/8] fix --- .../scala/org/apache/comet/vector/NativeUtil.scala | 2 +- native/core/src/execution/utils.rs | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 554a53e3b0..1def60896e 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -56,7 +56,7 @@ class NativeUtil { */ private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider - /** + /** * Allocates Arrow structs for the given number of columns. * * @param numCols diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index 7546726214..57b00be125 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use std::mem::forget; use std::sync::Arc; use arrow::{ @@ -102,14 +103,20 @@ impl SparkArrowConvert for ArrayData { /// Move this ArrowData to pointers of Arrow C data interface. fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> { - unsafe { std::ptr::replace(array as *mut FFI_ArrowArray, FFI_ArrowArray::new(self)) }; - unsafe { + let jvm_array = + unsafe { std::ptr::replace(array as *mut FFI_ArrowArray, FFI_ArrowArray::new(self)) }; + let jvm_schema = unsafe { std::ptr::replace( schema as *mut FFI_ArrowSchema, FFI_ArrowSchema::try_from(self.data_type())?, ) }; + // Don't deallocate the memory of the ArrowArray and ArrowSchema since they are allocated in Java. + // They will be deallocated in JVM. + forget(jvm_array); + forget(jvm_schema); + Ok(()) } } From 8614192e2c80b2fc8f829d7fe9aca19cc3c15409 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 4 Sep 2024 09:12:09 -0700 Subject: [PATCH 4/8] For review --- .../org/apache/comet/vector/NativeUtil.scala | 64 ++++++++++++++----- native/core/src/execution/utils.rs | 11 +--- .../org/apache/comet/CometExecIterator.scala | 21 ++---- 3 files changed, 54 insertions(+), 42 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 1def60896e..33af8662fe 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -62,20 +62,20 @@ class NativeUtil { * @param numCols * the number of columns * @return - * a pair of arrays containing memory addresses of Arrow arrays and Arrow schemas + * a pair of Arrow arrays and Arrow schemas */ - def allocateArrowStructs(numCols: Int): (Array[Long], Array[Long]) = { - val arrayAddrs = new Array[Long](numCols) - val schemaAddrs = new Array[Long](numCols) + def allocateArrowStructs(numCols: Int): (Array[ArrowArray], Array[ArrowSchema]) = { + val arrays = new 
Array[ArrowArray](numCols) + val schemas = new Array[ArrowSchema](numCols) (0 until numCols).foreach { index => val arrowSchema = ArrowSchema.allocateNew(allocator) val arrowArray = ArrowArray.allocateNew(allocator) - arrayAddrs(index) = arrowArray.memoryAddress() - schemaAddrs(index) = arrowSchema.memoryAddress() + arrays(index) = arrowArray + schemas(index) = arrowSchema } - (arrayAddrs, schemaAddrs) + (arrays, schemas) } /** @@ -123,21 +123,54 @@ class NativeUtil { batch.numRows() } + /** + * Gets the next batch from native execution. + * + * @param numOutputCols + * The number of output columns + * @param func + * The function to call to get the next batch + * @return + * The next batch as a ColumnarBatch, or None if there are no more batches + */ + def getNextBatch( + numOutputCols: Int, + func: (Array[Long], Array[Long]) => Long): Option[ColumnarBatch] = { + val (arrays, schemas) = allocateArrowStructs(numOutputCols) + + val arrayAddrs = arrays.map(_.memoryAddress()) + val schemaAddrs = schemas.map(_.memoryAddress()) + + val result = func(arrayAddrs, schemaAddrs) + + result match { + case -1 => + // EOF + None + case numRows => + val cometVectors = importVector(arrays, schemas) + Some(new ColumnarBatch(cometVectors.toArray, numRows.toInt)) + case flag => + throw new IllegalStateException(s"Invalid native flag: $flag") + } + } + + /** * Imports a list of Arrow addresses from native execution, and return a list of Comet vectors. 
* - * @param arrayAddress - * a list containing paris of Arrow addresses from the native, in the format of (address of - * Arrow array, address of Arrow schema) + * @param arrays + * a list of Arrow array + * @param schemas + * a list of Arrow schema * @return * a list of Comet vectors */ - def importVector(arrayAddrs: Array[Long], schemaAddrs: Array[Long]): Seq[CometVector] = { + def importVector(arrays: Array[ArrowArray], schemas: Array[ArrowSchema]): Seq[CometVector] = { val arrayVectors = mutable.ArrayBuffer.empty[CometVector] - (0 until arrayAddrs.length).foreach { i => - val arrowSchema = ArrowSchema.wrap(schemaAddrs(i)) - val arrowArray = ArrowArray.wrap(arrayAddrs(i)) + (0 until arrays.length).foreach { i => + val arrowSchema = schemas(i) + val arrowArray = arrays(i) // Native execution should always have 'useDecimal128' set to true since it doesn't support // other cases. @@ -145,9 +178,6 @@ class NativeUtil { importer.importVector(arrowArray, arrowSchema, dictionaryProvider), true, dictionaryProvider) - - arrowArray.close() - arrowSchema.close() } arrayVectors.toSeq } diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index 57b00be125..7546726214 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::mem::forget; use std::sync::Arc; use arrow::{ @@ -103,20 +102,14 @@ impl SparkArrowConvert for ArrayData { /// Move this ArrowData to pointers of Arrow C data interface. 
fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> { - let jvm_array = - unsafe { std::ptr::replace(array as *mut FFI_ArrowArray, FFI_ArrowArray::new(self)) }; - let jvm_schema = unsafe { + unsafe { std::ptr::replace(array as *mut FFI_ArrowArray, FFI_ArrowArray::new(self)) }; + unsafe { std::ptr::replace( schema as *mut FFI_ArrowSchema, FFI_ArrowSchema::try_from(self.data_type())?, ) }; - // Don't deallocate the memory of the ArrowArray and ArrowSchema since they are allocated in Java. - // They will be deallocated in JVM. - forget(jvm_array); - forget(jvm_schema); - Ok(()) } } diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index eea92a1be6..07dd80c39e 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -101,22 +101,11 @@ class CometExecIterator( } def getNextBatch(): Option[ColumnarBatch] = { - val (arrayAddrs, schemaAddrs) = nativeUtil.allocateArrowStructs(numOutputCols) - - // we execute the native plan each time we need another output batch and this could - // result in multiple input batches being processed - val result = nativeLib.executePlan(plan, arrayAddrs, schemaAddrs) - - result match { - case -1 => - // EOF - None - case numRows => - val cometVectors = nativeUtil.importVector(arrayAddrs, schemaAddrs) - Some(new ColumnarBatch(cometVectors.toArray, numRows.toInt)) - case flag => - throw new IllegalStateException(s"Invalid native flag: $flag") - } + nativeUtil.getNextBatch( + numOutputCols, + (arrayAddrs, schemaAddrs) => { + nativeLib.executePlan(plan, arrayAddrs, schemaAddrs) + }) } override def hasNext: Boolean = { From ddb70b1f73207e718e8bf6e22ab7e23de14aed06 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 4 Sep 2024 11:37:45 -0700 Subject: [PATCH 5/8] Try --- native/core/src/execution/utils.rs | 7 +++++-- 1 file changed, 5 
insertions(+), 2 deletions(-) diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index 7546726214..15f092ea75 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -102,14 +102,17 @@ impl SparkArrowConvert for ArrayData { /// Move this ArrowData to pointers of Arrow C data interface. fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> { - unsafe { std::ptr::replace(array as *mut FFI_ArrowArray, FFI_ArrowArray::new(self)) }; - unsafe { + let jvm_array = unsafe { std::ptr::replace(array as *mut FFI_ArrowArray, FFI_ArrowArray::new(self)) }; + let jvm_schema = unsafe { std::ptr::replace( schema as *mut FFI_ArrowSchema, FFI_ArrowSchema::try_from(self.data_type())?, ) }; + std::mem::forget(jvm_array); + std::mem::forget(jvm_schema); + Ok(()) } } From 4beb19634ac5821c08357cf79c651a459a6363e5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 4 Sep 2024 11:47:49 -0700 Subject: [PATCH 6/8] check alignment --- native/core/src/execution/utils.rs | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index 15f092ea75..ca48c94d95 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -102,13 +102,21 @@ impl SparkArrowConvert for ArrayData { /// Move this ArrowData to pointers of Arrow C data interface. 
fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> { - let jvm_array = unsafe { std::ptr::replace(array as *mut FFI_ArrowArray, FFI_ArrowArray::new(self)) }; - let jvm_schema = unsafe { - std::ptr::replace( - schema as *mut FFI_ArrowSchema, - FFI_ArrowSchema::try_from(self.data_type())?, - ) - }; + let array_ptr = array as *mut FFI_ArrowArray; + let schema_ptr = schema as *mut FFI_ArrowSchema; + + let array_align = std::mem::align_of::(); + let schema_align = std::mem::align_of::(); + + if array_ptr.align_offset(array_align) != 0 || schema_ptr.align_offset(schema_align) != 0 { + return Err(ExecutionError::ArrowError( + "Pointer alignment is not correct".to_string(), + )); + } + + let jvm_array = unsafe { std::ptr::replace(array_ptr, FFI_ArrowArray::new(self)) }; + let jvm_schema = + unsafe { std::ptr::replace(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?) }; std::mem::forget(jvm_array); std::mem::forget(jvm_schema); From b07cf44894e40e1a1af50b66199722c035e38b43 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 4 Sep 2024 14:01:19 -0700 Subject: [PATCH 7/8] Try --- .../scala/org/apache/comet/vector/NativeUtil.scala | 11 +++++++++++ native/core/src/execution/utils.rs | 9 +++------ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 33af8662fe..b60db58b28 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -19,9 +19,13 @@ package org.apache.comet.vector +import java.nio.ByteOrder + import scala.collection.mutable import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data} +import org.apache.arrow.c.NativeUtil.NULL +import org.apache.arrow.memory.util.MemoryUtil import org.apache.arrow.vector.VectorSchemaRoot import 
org.apache.arrow.vector.dictionary.DictionaryProvider import org.apache.spark.SparkException @@ -70,6 +74,13 @@ class NativeUtil { (0 until numCols).foreach { index => val arrowSchema = ArrowSchema.allocateNew(allocator) + + // Manually fill NULL to `release` slot of ArrowSchema because ArrowSchema doesn't provide + // `markReleased`. + val buffer = + MemoryUtil.directBuffer(arrowSchema.memoryAddress(), 72).order(ByteOrder.nativeOrder) + buffer.putLong(56, NULL); + val arrowArray = ArrowArray.allocateNew(allocator) arrays(index) = arrowArray schemas(index) = arrowSchema diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index ca48c94d95..bd760bc3d7 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -108,18 +108,15 @@ impl SparkArrowConvert for ArrayData { let array_align = std::mem::align_of::(); let schema_align = std::mem::align_of::(); + // Check if the pointer alignment is correct for `replace`. if array_ptr.align_offset(array_align) != 0 || schema_ptr.align_offset(schema_align) != 0 { return Err(ExecutionError::ArrowError( "Pointer alignment is not correct".to_string(), )); } - let jvm_array = unsafe { std::ptr::replace(array_ptr, FFI_ArrowArray::new(self)) }; - let jvm_schema = - unsafe { std::ptr::replace(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?) }; - - std::mem::forget(jvm_array); - std::mem::forget(jvm_schema); + unsafe { std::ptr::replace(array_ptr, FFI_ArrowArray::new(self)) }; + unsafe { std::ptr::replace(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?) 
}; Ok(()) } From ad61ee9a9509f8ad04abb24e614dd3b936a9e351 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Sep 2024 12:54:26 -0700 Subject: [PATCH 8/8] Add comment --- common/src/main/scala/org/apache/comet/vector/NativeUtil.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index b60db58b28..4b113d89a8 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -77,6 +77,8 @@ class NativeUtil { // Manually fill NULL to `release` slot of ArrowSchema because ArrowSchema doesn't provide // `markReleased`. + // The total size of ArrowSchema is 72 bytes. + // The `release` slot is at offset 56 in the ArrowSchema struct. val buffer = MemoryUtil.directBuffer(arrowSchema.memoryAddress(), 72).order(ByteOrder.nativeOrder) buffer.putLong(56, NULL);