diff --git a/common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala b/common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala
deleted file mode 100644
index 2e97a0dcc6..0000000000
--- a/common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.comet.vector
-
-import org.apache.arrow.c.ArrowArray
-import org.apache.arrow.c.ArrowSchema
-
-/**
- * A wrapper class to hold the exported Arrow arrays and schemas.
- *
- * @param batch
- *   a list containing number of rows + pairs of memory addresses in the format of (address of
- *   Arrow array, address of Arrow schema)
- * @param arrowSchemas
- *   the exported Arrow schemas, needs to be deallocated after being moved by the native executor
- * @param arrowArrays
- *   the exported Arrow arrays, needs to be deallocated after being moved by the native executor
- */
-case class ExportedBatch(
-    batch: Array[Long],
-    arrowSchemas: Array[ArrowSchema],
-    arrowArrays: Array[ArrowArray]) {
-  def close(): Unit = {
-    arrowSchemas.foreach(_.close())
-    arrowArrays.foreach(_.close())
-  }
-}
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index eed8fd05b1..5149c73402 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -47,50 +47,38 @@ class NativeUtil {
-   * an exported batches object containing an array containing number of rows + pairs of memory
-   * addresses in the format of (address of Arrow array, address of Arrow schema)
+   * the number of rows of the exported batch
    */
-  def exportBatch(batch: ColumnarBatch): ExportedBatch = {
-    val exportedVectors = mutable.ArrayBuffer.empty[Long]
-    exportedVectors += batch.numRows()
-
-    // Run checks prior to exporting the batch
-    (0 until batch.numCols()).foreach { index =>
-      val c = batch.column(index)
-      if (!c.isInstanceOf[CometVector]) {
-        batch.close()
-        throw new SparkException(
-          "Comet execution only takes Arrow Arrays, but got " +
-            s"${c.getClass}")
-      }
-    }
-
-    val arrowSchemas = mutable.ArrayBuffer.empty[ArrowSchema]
-    val arrowArrays = mutable.ArrayBuffer.empty[ArrowArray]
-
+  def exportBatch(
+      arrayAddrs: Array[Long],
+      schemaAddrs: Array[Long],
+      batch: ColumnarBatch): Int = {
     (0 until batch.numCols()).foreach { index =>
-      val cometVector = batch.column(index).asInstanceOf[CometVector]
-      val valueVector = cometVector.getValueVector
-
-      val provider = if (valueVector.getField.getDictionary != null) {
-        cometVector.getDictionaryProvider
-      } else {
-        null
+      batch.column(index) match {
+        case a: CometVector =>
+          val valueVector = a.getValueVector
+
+          val provider = if (valueVector.getField.getDictionary != null) {
+            a.getDictionaryProvider
+          } else {
+            null
+          }
+
+          // The array and schema structures are allocated by the native side.
+          // We don't need to deallocate them here.
+          val arrowSchema = ArrowSchema.wrap(schemaAddrs(index))
+          val arrowArray = ArrowArray.wrap(arrayAddrs(index))
+          Data.exportVector(
+            allocator,
+            getFieldVector(valueVector, "export"),
+            provider,
+            arrowArray,
+            arrowSchema)
+        case c =>
+          throw new SparkException(
+            "Comet execution only takes Arrow Arrays, but got " +
+              s"${c.getClass}")
       }
-
-      val arrowSchema = ArrowSchema.allocateNew(allocator)
-      val arrowArray = ArrowArray.allocateNew(allocator)
-      arrowSchemas += arrowSchema
-      arrowArrays += arrowArray
-      Data.exportVector(
-        allocator,
-        getFieldVector(valueVector, "export"),
-        provider,
-        arrowArray,
-        arrowSchema)
-
-      exportedVectors += arrowArray.memoryAddress()
-      exportedVectors += arrowSchema.memoryAddress()
     }
 
-    ExportedBatch(exportedVectors.toArray, arrowSchemas.toArray, arrowArrays.toArray)
+    batch.numRows()
   }
 
   /**
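Note on the handoff above: `NativeUtil.exportBatch` no longer allocates the `ArrowSchema`/`ArrowArray` structs on the JVM side. It wraps addresses that the native caller allocated and fills them via `Data.exportVector`, so ownership of the exported buffers moves across the Arrow C Data Interface exactly once. For readers unfamiliar with those move semantics, here is a self-contained sketch of the same round trip in a single Rust process, using the `arrow` crate's `ffi` module; this is an illustration of the protocol, not code from this patch:

```rust
use arrow::array::{make_array, Array, Int32Array};
use arrow::ffi::{from_ffi, to_ffi};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Producer side (the JVM in Comet): export an array into a pair of
    // C Data Interface structs. After the export, the structs keep the
    // underlying buffers alive.
    let producer = Int32Array::from(vec![Some(1), None, Some(3)]);
    let (ffi_array, ffi_schema) = to_ffi(&producer.to_data())?;

    // Consumer side (the native plan in Comet): import the structs.
    // Ownership of the buffers moves to the imported ArrayData; the
    // producer must not touch them again.
    let data = unsafe { from_ffi(ffi_array, &ffi_schema)? };
    let consumer = make_array(data);

    assert_eq!(consumer.len(), 3);
    Ok(())
}
```

After the move, the consumer is responsible for invoking the release callback recorded in the structs; the producer only has to free the struct containers themselves, which in this patch become native-allocated (the `Rc`s in `get_next` below) rather than JVM-allocated.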
diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs
index 59616efbb2..0816a5c111 100644
--- a/native/core/src/execution/operators/scan.rs
+++ b/native/core/src/execution/operators/scan.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use futures::Stream;
+use itertools::Itertools;
+use std::rc::Rc;
 use std::{
     any::Any,
     pin::Pin,
@@ -22,14 +25,6 @@ use std::{
     task::{Context, Poll},
 };
 
-use futures::Stream;
-use itertools::Itertools;
-
-use arrow::compute::{cast_with_options, CastOptions};
-use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
-use arrow_data::ArrayData;
-use arrow_schema::{DataType, Field, Schema, SchemaRef};
-
 use crate::{
     errors::CometError,
     execution::{
@@ -38,6 +33,12 @@ use crate::{
     },
     jvm_bridge::{jni_call, JVMClasses},
 };
+use arrow::compute::{cast_with_options, CastOptions};
+use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
+use arrow_data::ffi::FFI_ArrowArray;
+use arrow_data::ArrayData;
+use arrow_schema::ffi::FFI_ArrowSchema;
+use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
 use datafusion::{
     execution::TaskContext,
@@ -45,10 +46,9 @@ use datafusion::{
     physical_plan::{ExecutionPlan, *},
 };
 use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult};
-use jni::{
-    objects::{GlobalRef, JLongArray, JObject, ReleaseMode},
-    sys::jlongArray,
-};
+use jni::objects::JValueGen;
+use jni::objects::{GlobalRef, JObject};
+use jni::sys::jsize;
 
 /// ScanExec reads batches of data from Spark via JNI. The source of the scan could be a file
 /// scan or the result of reading a broadcast or shuffle exchange.
@@ -86,7 +86,7 @@ impl ScanExec {
         // may end up either unpacking dictionary arrays or dictionary-encoding arrays.
         // Dictionary-encoded primitive arrays are always unpacked.
         let first_batch = if let Some(input_source) = input_source.as_ref() {
-            ScanExec::get_next(exec_context_id, input_source.as_obj())?
+            ScanExec::get_next(exec_context_id, input_source.as_obj(), data_types.len())?
         } else {
             InputBatch::EOF
         };
@@ -153,6 +153,7 @@ impl ScanExec {
                 let next_batch = ScanExec::get_next(
                     self.exec_context_id,
                     self.input_source.as_ref().unwrap().as_obj(),
+                    self.data_types.len(),
                 )?;
                 *current_batch = Some(next_batch);
             }
@@ -161,7 +162,11 @@
     }
 
     /// Invokes JNI call to get next batch.
-    fn get_next(exec_context_id: i64, iter: &JObject) -> Result<InputBatch, CometError> {
+    fn get_next(
+        exec_context_id: i64,
+        iter: &JObject,
+        num_cols: usize,
+    ) -> Result<InputBatch, CometError> {
         if exec_context_id == TEST_EXEC_CONTEXT_ID {
             // This is a unit test. We don't need to call JNI.
             return Ok(InputBatch::EOF);
@@ -175,49 +180,60 @@
         }
 
         let mut env = JVMClasses::get_env()?;
 
-        let batch_object: JObject = unsafe {
-            jni_call!(&mut env,
-            comet_batch_iterator(iter).next() -> JObject)?
-        };
-
-        if batch_object.is_null() {
-            return Err(CometError::from(ExecutionError::GeneralError(format!(
-                "Null batch object. Plan id: {}",
-                exec_context_id
-            ))));
+        let mut array_addrs = Vec::with_capacity(num_cols);
+        let mut schema_addrs = Vec::with_capacity(num_cols);
+
+        for _ in 0..num_cols {
+            let arrow_array = Rc::new(FFI_ArrowArray::empty());
+            let arrow_schema = Rc::new(FFI_ArrowSchema::empty());
+            let (array_ptr, schema_ptr) = (
+                Rc::into_raw(arrow_array) as i64,
+                Rc::into_raw(arrow_schema) as i64,
+            );
+
+            array_addrs.push(array_ptr);
+            schema_addrs.push(schema_ptr);
         }
 
-        let batch_object = unsafe { JLongArray::from_raw(batch_object.as_raw() as jlongArray) };
+        // Prepare the Java array parameters
+        let long_array_addrs = env.new_long_array(num_cols as jsize)?;
+        let long_schema_addrs = env.new_long_array(num_cols as jsize)?;
 
-        let addresses = unsafe { env.get_array_elements(&batch_object, ReleaseMode::NoCopyBack)? };
+        env.set_long_array_region(&long_array_addrs, 0, &array_addrs)?;
+        env.set_long_array_region(&long_schema_addrs, 0, &schema_addrs)?;
 
-        // First element is the number of rows.
-        let num_rows = unsafe { *addresses.as_ptr() as i64 };
+        let array_obj = JObject::from(long_array_addrs);
+        let schema_obj = JObject::from(long_schema_addrs);
 
-        if num_rows < 0 {
-            return Ok(InputBatch::EOF);
-        }
+        let array_obj = JValueGen::Object(array_obj.as_ref());
+        let schema_obj = JValueGen::Object(schema_obj.as_ref());
+
+        let num_rows: i32 = unsafe {
+            jni_call!(&mut env,
+            comet_batch_iterator(iter).next(array_obj, schema_obj) -> i32)?
+        };
 
-        let array_num = addresses.len() - 1;
-        if array_num % 2 != 0 {
-            return Err(CometError::Internal(format!(
-                "Invalid number of Arrow Array addresses: {}",
-                array_num
-            )));
+        if num_rows == -1 {
+            return Ok(InputBatch::EOF);
         }
 
-        let num_arrays = array_num / 2;
-        let array_elements = unsafe { addresses.as_ptr().add(1) };
-        let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_arrays);
+        let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_cols);
 
-        for i in 0..num_arrays {
-            let array_ptr = unsafe { *(array_elements.add(i * 2)) };
-            let schema_ptr = unsafe { *(array_elements.add(i * 2 + 1)) };
+        for i in 0..num_cols {
+            let array_ptr = array_addrs[i];
+            let schema_ptr = schema_addrs[i];
             let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?;
             // TODO: validate array input data
             inputs.push(make_array(array_data));
+
+            // Drop the Rcs to avoid a memory leak
+            unsafe {
+                Rc::from_raw(array_ptr as *const FFI_ArrowArray);
+                Rc::from_raw(schema_ptr as *const FFI_ArrowSchema);
+            }
        }
 
        Ok(InputBatch::new(inputs, Some(num_rows as usize)))
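The `Rc::into_raw`/`Rc::from_raw` pairing in `get_next` above is what keeps the empty FFI structs alive while the JVM writes into them: `into_raw` deliberately leaks each struct so that only a raw address crosses the JNI boundary, and `from_raw` reclaims ownership afterwards so the struct is dropped. A minimal, std-only sketch of that round trip (illustrative values, not code from this patch):

```rust
use std::cell::Cell;
use std::rc::Rc;

fn main() {
    // Allocate and leak: after into_raw, nothing frees the value until we
    // explicitly reclaim it, so the raw address can safely cross FFI.
    let slot = Rc::new(Cell::new(0u64));
    let addr = Rc::into_raw(slot) as i64;

    // Foreign code (in Comet: the JVM, reached via addresses packed into a
    // jlongArray) fills the allocation in place while we are not holding it.
    unsafe { (*(addr as *const Cell<u64>)).set(42) };

    // Reclaim ownership; dropping the Rc frees the allocation. Skipping this
    // step is exactly the leak the "Drop the Rcs" comment guards against.
    let reclaimed = unsafe { Rc::from_raw(addr as *const Cell<u64>) };
    assert_eq!(reclaimed.get(), 42);
}
```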
diff --git a/native/core/src/jvm_bridge/batch_iterator.rs b/native/core/src/jvm_bridge/batch_iterator.rs
index 06f43a8ce4..4870624d2b 100644
--- a/native/core/src/jvm_bridge/batch_iterator.rs
+++ b/native/core/src/jvm_bridge/batch_iterator.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use jni::signature::Primitive;
 use jni::{
     errors::Result as JniResult,
     objects::{JClass, JMethodID},
@@ -37,8 +38,8 @@ impl<'a> CometBatchIterator<'a> {
 
         Ok(CometBatchIterator {
             class,
-            method_next: env.get_method_id(Self::JVM_CLASS, "next", "()[J")?,
-            method_next_ret: ReturnType::Array,
+            method_next: env.get_method_id(Self::JVM_CLASS, "next", "([J[J)I")?,
+            method_next_ret: ReturnType::Primitive(Primitive::Int),
         })
     }
 }
diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java b/spark/src/main/java/org/apache/comet/CometBatchIterator.java
index eb7506b889..accd57c208 100644
--- a/spark/src/main/java/org/apache/comet/CometBatchIterator.java
+++ b/spark/src/main/java/org/apache/comet/CometBatchIterator.java
@@ -23,7 +23,6 @@
 
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 
-import org.apache.comet.vector.ExportedBatch;
 import org.apache.comet.vector.NativeUtil;
 
 /**
@@ -35,41 +34,25 @@ public class CometBatchIterator {
   final Iterator<ColumnarBatch> input;
   final NativeUtil nativeUtil;
 
-  private ExportedBatch lastBatch;
-
   CometBatchIterator(Iterator<ColumnarBatch> input, NativeUtil nativeUtil) {
     this.input = input;
     this.nativeUtil = nativeUtil;
-    this.lastBatch = null;
   }
 
   /**
-   * Get the next batches of Arrow arrays. It will consume input iterator and return Arrow arrays by
-   * addresses. If the input iterator is done, it will return a one negative element array
-   * indicating the end of the iterator.
+   * Get the next batch of Arrow arrays.
+   *
+   * @param arrayAddrs The addresses of the ArrowArray structures.
+   * @param schemaAddrs The addresses of the ArrowSchema structures.
+   * @return the number of rows of the current batch, or -1 if there are no more batches.
    */
-  public long[] next() {
-    // Native side already copied the content of ArrowSchema and ArrowArray. We should deallocate
-    // the ArrowSchema and ArrowArray base structures allocated in JVM.
-    if (lastBatch != null) {
-      lastBatch.close();
-      lastBatch = null;
-    }
-
+  public int next(long[] arrayAddrs, long[] schemaAddrs) {
     boolean hasBatch = input.hasNext();
     if (!hasBatch) {
-      return new long[] {-1};
+      return -1;
    }
 
-    lastBatch = nativeUtil.exportBatch(input.next());
-    return lastBatch.batch();
-  }
-
-  public void close() {
-    if (lastBatch != null) {
-      lastBatch.close();
-      lastBatch = null;
-    }
+    return nativeUtil.exportBatch(arrayAddrs, schemaAddrs, input.next());
  }
 }
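Taken together, the Java and Rust sides now implement a simple pull contract: the native caller allocates one ArrowArray slot and one ArrowSchema slot per column, passes the two address arrays to `next`, and interprets the return value as a row count, with -1 as the end-of-stream sentinel (hence the `([J[J)I` JNI signature above). A toy, same-process Rust stand-in for that handshake; the names here are hypothetical and not part of Comet:

```rust
/// Stand-in for CometBatchIterator: writes one "export" per column slot and
/// returns the row count, or -1 when the input is exhausted.
struct ToyBatchIterator {
    remaining: Vec<(Vec<i64>, i32)>, // (per-column handles, row count)
}

impl ToyBatchIterator {
    fn next(&mut self, array_slots: &mut [i64]) -> i32 {
        match self.remaining.pop() {
            Some((cols, rows)) => {
                array_slots.copy_from_slice(&cols);
                rows
            }
            None => -1, // same EOF sentinel as the new JNI contract
        }
    }
}

fn main() {
    let mut iter = ToyBatchIterator {
        remaining: vec![(vec![0xA0, 0xB0], 3), (vec![0xA1, 0xB1], 2)],
    };
    let mut slots = vec![0i64; 2];
    loop {
        let num_rows = iter.next(&mut slots);
        if num_rows == -1 {
            break; // iterator exhausted
        }
        println!("batch with {num_rows} rows across {} columns", slots.len());
    }
}
```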
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index f1e77fb5d1..29eb2f0ca9 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -159,8 +159,6 @@
     }
     nativeLib.releasePlan(plan)
 
-    cometBatchIterators.foreach(_.close())
-
     // The allocator thinks the exported ArrowArray and ArrowSchema structs are not released,
     // so it will report:
     // Caused by: java.lang.IllegalStateException: Memory was leaked by query.