44 changes: 0 additions & 44 deletions common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala

This file was deleted.

69 changes: 29 additions & 40 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -47,50 +47,39 @@ class NativeUtil {
   * @param arrayAddrs the addresses of the ArrowArray structures allocated by the native side
   * @param schemaAddrs the addresses of the ArrowSchema structures allocated by the native side
   * @param batch the Comet columnar batch to export
   * @return the number of rows in the exported batch
   */
def exportBatch(batch: ColumnarBatch): ExportedBatch = {
val exportedVectors = mutable.ArrayBuffer.empty[Long]
exportedVectors += batch.numRows()

// Run checks prior to exporting the batch
(0 until batch.numCols()).foreach { index =>
val c = batch.column(index)
if (!c.isInstanceOf[CometVector]) {
batch.close()
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
}
}

val arrowSchemas = mutable.ArrayBuffer.empty[ArrowSchema]
val arrowArrays = mutable.ArrayBuffer.empty[ArrowArray]

def exportBatch(
arrayAddrs: Array[Long],
schemaAddrs: Array[Long],
batch: ColumnarBatch): Int = {
(0 until batch.numCols()).foreach { index =>
val cometVector = batch.column(index).asInstanceOf[CometVector]
val valueVector = cometVector.getValueVector

val provider = if (valueVector.getField.getDictionary != null) {
cometVector.getDictionaryProvider
} else {
null
batch.column(index) match {
case a: CometVector =>
val valueVector = a.getValueVector

val provider = if (valueVector.getField.getDictionary != null) {
a.getDictionaryProvider
} else {
null
}

          // The array and schema structures are allocated by the native side.
          // They do not need to be deallocated here.
val arrowSchema = ArrowSchema.wrap(schemaAddrs(index))
val arrowArray = ArrowArray.wrap(arrayAddrs(index))
Data.exportVector(
allocator,
getFieldVector(valueVector, "export"),
provider,
arrowArray,
arrowSchema)
case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
}

val arrowSchema = ArrowSchema.allocateNew(allocator)
val arrowArray = ArrowArray.allocateNew(allocator)
arrowSchemas += arrowSchema
arrowArrays += arrowArray
Data.exportVector(
allocator,
getFieldVector(valueVector, "export"),
provider,
arrowArray,
arrowSchema)

exportedVectors += arrowArray.memoryAddress()
exportedVectors += arrowSchema.memoryAddress()
}

ExportedBatch(exportedVectors.toArray, arrowSchemas.toArray, arrowArrays.toArray)
batch.numRows()
}

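The new contract inverts the old ownership model: the native side allocates the ArrowArray/ArrowSchema structs and passes their addresses in `arrayAddrs`/`schemaAddrs`, while the JVM only exports each column into them and returns the row count. As a minimal sketch of the underlying Arrow C Data Interface round trip, the snippet below plays both halves of the handshake in Rust with arrow-rs; in Comet the producer half is Arrow Java's `Data.exportVector` and the consumer half is `ScanExec::get_next` via `ArrayData::from_spark`.

```rust
use arrow::array::{make_array, Array, Int32Array};
use arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema};

fn main() {
    // Producer half (played by Data.exportVector on the JVM in Comet):
    // export a column into a pair of C Data Interface structs.
    let column = Int32Array::from(vec![1, 2, 3]).into_data();
    let ffi_array = FFI_ArrowArray::new(&column);
    let ffi_schema = FFI_ArrowSchema::try_from(column.data_type()).expect("schema");

    // Consumer half (played by ScanExec::get_next in Comet): rebuild the
    // ArrayData from the two structs and wrap it as an ArrayRef.
    let imported = unsafe { from_ffi(ffi_array, &ffi_schema) }.expect("import");
    let array = make_array(imported);
    assert_eq!(array.len(), 3);
}
```

Only the two struct addresses and the release callbacks embedded in them cross the language boundary, which is what lets the JVM and native sides exchange columns without copying buffers.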
100 changes: 58 additions & 42 deletions native/core/src/execution/operators/scan.rs
@@ -15,21 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use futures::Stream;
use itertools::Itertools;
use std::rc::Rc;
use std::{
any::Any,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};

use futures::Stream;
use itertools::Itertools;

use arrow::compute::{cast_with_options, CastOptions};
use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
use arrow_data::ArrayData;
use arrow_schema::{DataType, Field, Schema, SchemaRef};

use crate::{
errors::CometError,
execution::{
@@ -38,17 +33,22 @@ use crate::{
},
jvm_bridge::{jni_call, JVMClasses},
};
use arrow::compute::{cast_with_options, CastOptions};
use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
use arrow_data::ffi::FFI_ArrowArray;
use arrow_data::ArrayData;
use arrow_schema::ffi::FFI_ArrowSchema;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::{
execution::TaskContext,
physical_expr::*,
physical_plan::{ExecutionPlan, *},
};
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult};
use jni::{
objects::{GlobalRef, JLongArray, JObject, ReleaseMode},
sys::jlongArray,
};
use jni::objects::JValueGen;
use jni::objects::{GlobalRef, JObject};
use jni::sys::jsize;

/// ScanExec reads batches of data from Spark via JNI. The source of the scan could be a file
/// scan or the result of reading a broadcast or shuffle exchange.
@@ -86,7 +86,7 @@ impl ScanExec {
// may end up either unpacking dictionary arrays or dictionary-encoding arrays.
// Dictionary-encoded primitive arrays are always unpacked.
let first_batch = if let Some(input_source) = input_source.as_ref() {
ScanExec::get_next(exec_context_id, input_source.as_obj())?
ScanExec::get_next(exec_context_id, input_source.as_obj(), data_types.len())?
} else {
InputBatch::EOF
};
Expand Down Expand Up @@ -153,6 +153,7 @@ impl ScanExec {
let next_batch = ScanExec::get_next(
self.exec_context_id,
self.input_source.as_ref().unwrap().as_obj(),
self.data_types.len(),
)?;
*current_batch = Some(next_batch);
}
@@ -161,7 +162,11 @@ }
}

/// Invokes JNI call to get next batch.
fn get_next(exec_context_id: i64, iter: &JObject) -> Result<InputBatch, CometError> {
fn get_next(
exec_context_id: i64,
iter: &JObject,
num_cols: usize,
) -> Result<InputBatch, CometError> {
if exec_context_id == TEST_EXEC_CONTEXT_ID {
// This is a unit test. We don't need to call JNI.
return Ok(InputBatch::EOF);
@@ -175,49 +180,60 @@
}

let mut env = JVMClasses::get_env()?;
let batch_object: JObject = unsafe {
jni_call!(&mut env,
comet_batch_iterator(iter).next() -> JObject)?
};

if batch_object.is_null() {
return Err(CometError::from(ExecutionError::GeneralError(format!(
"Null batch object. Plan id: {}",
exec_context_id
))));
let mut array_addrs = Vec::with_capacity(num_cols);
let mut schema_addrs = Vec::with_capacity(num_cols);

for _ in 0..num_cols {
let arrow_array = Rc::new(FFI_ArrowArray::empty());
let arrow_schema = Rc::new(FFI_ArrowSchema::empty());
let (array_ptr, schema_ptr) = (
Rc::into_raw(arrow_array) as i64,
Rc::into_raw(arrow_schema) as i64,
);

array_addrs.push(array_ptr);
schema_addrs.push(schema_ptr);
}

let batch_object = unsafe { JLongArray::from_raw(batch_object.as_raw() as jlongArray) };
        // Prepare the Java array parameters
let long_array_addrs = env.new_long_array(num_cols as jsize)?;
let long_schema_addrs = env.new_long_array(num_cols as jsize)?;

let addresses = unsafe { env.get_array_elements(&batch_object, ReleaseMode::NoCopyBack)? };
env.set_long_array_region(&long_array_addrs, 0, &array_addrs)?;
env.set_long_array_region(&long_schema_addrs, 0, &schema_addrs)?;

// First element is the number of rows.
let num_rows = unsafe { *addresses.as_ptr() as i64 };
let array_obj = JObject::from(long_array_addrs);
let schema_obj = JObject::from(long_schema_addrs);

if num_rows < 0 {
return Ok(InputBatch::EOF);
}
let array_obj = JValueGen::Object(array_obj.as_ref());
let schema_obj = JValueGen::Object(schema_obj.as_ref());

let num_rows: i32 = unsafe {
jni_call!(&mut env,
comet_batch_iterator(iter).next(array_obj, schema_obj) -> i32)?
};

let array_num = addresses.len() - 1;
if array_num % 2 != 0 {
return Err(CometError::Internal(format!(
"Invalid number of Arrow Array addresses: {}",
array_num
)));
if num_rows == -1 {
return Ok(InputBatch::EOF);
}

let num_arrays = array_num / 2;
let array_elements = unsafe { addresses.as_ptr().add(1) };
let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_arrays);
let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_cols);

for i in 0..num_arrays {
let array_ptr = unsafe { *(array_elements.add(i * 2)) };
let schema_ptr = unsafe { *(array_elements.add(i * 2 + 1)) };
for i in 0..num_cols {
let array_ptr = array_addrs[i];
let schema_ptr = schema_addrs[i];
let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?;

// TODO: validate array input data

inputs.push(make_array(array_data));

            // Drop the Rcs to avoid a memory leak
unsafe {
Rc::from_raw(array_ptr as *const FFI_ArrowArray);
Rc::from_raw(schema_ptr as *const FFI_ArrowSchema);
}
}

Ok(InputBatch::new(inputs, Some(num_rows as usize)))
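A subtlety worth calling out in the allocation loop above: the empty FFI structs live on the Rust heap, but their addresses must remain valid across the JNI call while the JVM exports into them. `Rc::into_raw` leaks each allocation to a stable pointer that can travel through JNI as an `i64`, and the matching `Rc::from_raw` at the bottom reclaims it once `ArrayData::from_spark` has consumed the contents. A minimal standalone sketch of that round trip, with the JNI call elided:

```rust
use std::rc::Rc;

use arrow::ffi::FFI_ArrowArray;

fn leak_and_reclaim() {
    // Allocate an empty struct for the JVM side to fill in.
    let ffi_array = Rc::new(FFI_ArrowArray::empty());

    // into_raw relinquishes ownership without freeing, so the address stays
    // valid for the duration of the JNI call.
    let addr = Rc::into_raw(ffi_array) as i64;

    // ... pass `addr` to CometBatchIterator.next() and import the column ...

    // from_raw reassumes ownership; dropping the Rc frees the allocation.
    // Skipping this step would leak one struct per column per batch.
    unsafe { drop(Rc::from_raw(addr as *const FFI_ArrowArray)) };
}
```

Since the pointer is never cloned, a `Box` would serve equally well here; `Rc` simply matches what the code above uses.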
5 changes: 3 additions & 2 deletions native/core/src/jvm_bridge/batch_iterator.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use jni::signature::Primitive;
use jni::{
errors::Result as JniResult,
objects::{JClass, JMethodID},
@@ -37,8 +38,8 @@ impl<'a> CometBatchIterator<'a> {

Ok(CometBatchIterator {
class,
method_next: env.get_method_id(Self::JVM_CLASS, "next", "()[J")?,
method_next_ret: ReturnType::Array,
method_next: env.get_method_id(Self::JVM_CLASS, "next", "([J[J)I")?,
method_next_ret: ReturnType::Primitive(Primitive::Int),
})
}
}
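The descriptor change encodes the new Java signature: in JNI shorthand `[J` is a `long[]` and `I` is a primitive `int`, so `([J[J)I` denotes `int next(long[], long[])`, replacing `()[J` (no arguments, returning `long[]`). A minimal sketch of the lookup using the plain jni crate (Comet routes this through its jvm_bridge machinery):

```rust
use jni::objects::JMethodID;
use jni::signature::{Primitive, ReturnType};
use jni::JNIEnv;

fn lookup_next(env: &mut JNIEnv) -> jni::errors::Result<(JMethodID, ReturnType)> {
    let method_next = env.get_method_id(
        "org/apache/comet/CometBatchIterator",
        "next",
        "([J[J)I", // was "()[J"
    )?;
    // The cached return type must agree with the descriptor's tail, hence
    // the switch from ReturnType::Array to ReturnType::Primitive(Int).
    Ok((method_next, ReturnType::Primitive(Primitive::Int)))
}
```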
33 changes: 8 additions & 25 deletions spark/src/main/java/org/apache/comet/CometBatchIterator.java
@@ -23,7 +23,6 @@

import org.apache.spark.sql.vectorized.ColumnarBatch;

import org.apache.comet.vector.ExportedBatch;
import org.apache.comet.vector.NativeUtil;

/**
@@ -35,41 +34,25 @@ public class CometBatchIterator {
final Iterator<ColumnarBatch> input;
final NativeUtil nativeUtil;

private ExportedBatch lastBatch;

CometBatchIterator(Iterator<ColumnarBatch> input, NativeUtil nativeUtil) {
this.input = input;
this.nativeUtil = nativeUtil;
this.lastBatch = null;
}

/**
* Get the next batches of Arrow arrays. It will consume input iterator and return Arrow arrays by
* addresses. If the input iterator is done, it will return a one negative element array
* indicating the end of the iterator.
   * Get the next batch of Arrow arrays.
*
* @param arrayAddrs The addresses of the ArrowArray structures.
* @param schemaAddrs The addresses of the ArrowSchema structures.
   * @return the number of rows in the current batch, or -1 if there are no more batches.
*/
public long[] next() {
// Native side already copied the content of ArrowSchema and ArrowArray. We should deallocate
// the ArrowSchema and ArrowArray base structures allocated in JVM.
if (lastBatch != null) {
lastBatch.close();
lastBatch = null;
}

public int next(long[] arrayAddrs, long[] schemaAddrs) {
boolean hasBatch = input.hasNext();

if (!hasBatch) {
return new long[] {-1};
return -1;
}

lastBatch = nativeUtil.exportBatch(input.next());
return lastBatch.batch();
}

public void close() {
if (lastBatch != null) {
lastBatch.close();
lastBatch = null;
}
return nativeUtil.exportBatch(arrayAddrs, schemaAddrs, input.next());
}
}
2 changes: 0 additions & 2 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -159,8 +159,6 @@ class CometExecIterator(
}
nativeLib.releasePlan(plan)

cometBatchIterators.foreach(_.close())

    // The allocator thinks the exported ArrowArray and ArrowSchema structs are not released,
// so it will report:
// Caused by: java.lang.IllegalStateException: Memory was leaked by query.