85 changes: 75 additions & 10 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -19,9 +19,13 @@

package org.apache.comet.vector

import java.nio.ByteOrder

import scala.collection.mutable

import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data}
import org.apache.arrow.c.NativeUtil.NULL
import org.apache.arrow.memory.util.MemoryUtil
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.spark.SparkException
@@ -56,6 +60,37 @@ class NativeUtil {
*/
private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider

/**
* Allocates Arrow structs for the given number of columns.
*
* @param numCols
* the number of columns
* @return
* a pair of Arrow arrays and Arrow schemas
*/
def allocateArrowStructs(numCols: Int): (Array[ArrowArray], Array[ArrowSchema]) = {
val arrays = new Array[ArrowArray](numCols)
val schemas = new Array[ArrowSchema](numCols)

(0 until numCols).foreach { index =>
val arrowSchema = ArrowSchema.allocateNew(allocator)

// Manually fill the `release` slot of the ArrowSchema with NULL, because ArrowSchema
// doesn't provide `markReleased`.
// The total size of ArrowSchema is 72 bytes.
// The `release` slot is at offset 56 in the ArrowSchema struct.
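// For reference, assuming the standard 64-bit layout of the Arrow C data
// interface: the struct has nine 8-byte fields -- three pointers (format, name,
// metadata), two int64s (flags, n_children), two pointers (children,
// dictionary), the `release` function pointer, and `private_data` -- so
// `release` sits at offset 7 * 8 = 56 and the struct spans 9 * 8 = 72 bytes.
// A NULL `release` marks the struct as already released, so the native side can
// safely overwrite it (see `move_to_spark` below).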
val buffer =
MemoryUtil.directBuffer(arrowSchema.memoryAddress(), 72).order(ByteOrder.nativeOrder)
buffer.putLong(56, NULL);

val arrowArray = ArrowArray.allocateNew(allocator)
arrays(index) = arrowArray
schemas(index) = arrowSchema
}

(arrays, schemas)
}

/**
* Exports a Comet `ColumnarBatch` into a list of memory addresses that can be consumed by the
* native execution.
@@ -101,31 +136,61 @@
batch.numRows()
}

/**
* Gets the next batch from native execution.
*
* @param numOutputCols
* The number of output columns
* @param func
* The function to call to get the next batch
* @return
* The next batch, or None if there are no more batches
*/
def getNextBatch(
numOutputCols: Int,
func: (Array[Long], Array[Long]) => Long): Option[ColumnarBatch] = {
val (arrays, schemas) = allocateArrowStructs(numOutputCols)

val arrayAddrs = arrays.map(_.memoryAddress())
val schemaAddrs = schemas.map(_.memoryAddress())

val result = func(arrayAddrs, schemaAddrs)

result match {
case -1 =>
// EOF
None
case numRows if numRows >= 0 =>
val cometVectors = importVector(arrays, schemas)
Some(new ColumnarBatch(cometVectors.toArray, numRows.toInt))
case flag =>
throw new IllegalStateException(s"Invalid native flag: $flag")
}
}

/**
* Imports Arrow arrays and schemas from native execution, and returns a list of Comet vectors.
*
* @param arrayAddress
* a list containing pairs of Arrow addresses from the native, in the format of (address of
* Arrow array, address of Arrow schema)
* @param arrays
* a list of Arrow arrays
* @param schemas
* a list of Arrow schemas
* @return
* a list of Comet vectors
*/
def importVector(arrayAddress: Array[Long]): Seq[CometVector] = {
def importVector(arrays: Array[ArrowArray], schemas: Array[ArrowSchema]): Seq[CometVector] = {
val arrayVectors = mutable.ArrayBuffer.empty[CometVector]

for (i <- arrayAddress.indices by 2) {
val arrowSchema = ArrowSchema.wrap(arrayAddress(i + 1))
val arrowArray = ArrowArray.wrap(arrayAddress(i))
(0 until arrays.length).foreach { i =>
val arrowSchema = schemas(i)
val arrowArray = arrays(i)

// Native execution should always have 'useDecimal128' set to true since it doesn't support
// other cases.
arrayVectors += CometVector.getVector(
importer.importVector(arrowArray, arrowSchema, dictionaryProvider),
true,
dictionaryProvider)

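// The importer moves the contents of the C structs into JVM-owned vectors
// (clearing their `release` callbacks), so the structs can be closed here to
// free only the off-heap struct allocations.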
arrowArray.close()
arrowSchema.close()
}
arrayVectors.toSeq
}
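Taken together, the JVM side now pre-allocates the FFI structs, hands their addresses to the native call, and imports the filled structs as Comet vectors. Below is a minimal sketch of the resulting protocol, assuming the `nativeUtil`, `nativeLib`, `plan`, and `numOutputCols` values used by `CometExecIterator` later in this diff; it is an illustration, not code from this PR.

```scala
// Hypothetical driver loop over the new callback-based API: the native side
// writes into the pre-allocated ArrowArray/ArrowSchema structs and returns the
// row count, or -1 once the output is exhausted (surfaced here as None).
val batches: Iterator[ColumnarBatch] = Iterator
  .continually(nativeUtil.getNextBatch(
    numOutputCols,
    (arrayAddrs, schemaAddrs) => nativeLib.executePlan(plan, arrayAddrs, schemaAddrs)))
  .takeWhile(_.isDefined)
  .map(_.get)
```

Because the structs are allocated per call and closed by `importVector`, no native-side state (the old `ffi_arrays` field) has to outlive the call.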
71 changes: 37 additions & 34 deletions native/core/src/execution/jni_api.rs
@@ -17,10 +17,7 @@

//! Define JNI APIs which can be called from Java/Scala.

use arrow::{
datatypes::DataType as ArrowDataType,
ffi::{FFI_ArrowArray, FFI_ArrowSchema},
};
use arrow::datatypes::DataType as ArrowDataType;
use arrow_array::RecordBatch;
use datafusion::{
execution::{
@@ -78,8 +75,6 @@ struct ExecutionContext {
pub input_sources: Vec<Arc<GlobalRef>>,
/// The record batch stream to pull results from
pub stream: Option<SendableRecordBatchStream>,
/// The FFI arrays. We need to keep them alive here.
pub ffi_arrays: Vec<(Arc<FFI_ArrowArray>, Arc<FFI_ArrowSchema>)>,
/// Configurations for DF execution
pub conf: HashMap<String, String>,
/// The Tokio runtime used for async.
@@ -177,7 +172,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
scans: vec![],
input_sources,
stream: None,
ffi_arrays: vec![],
conf: configs,
runtime,
metrics,
@@ -265,14 +259,33 @@ fn parse_bool(conf: &HashMap<String, String>, name: &str) -> CometResult<bool> {
}

/// Prepares Arrow arrays for output.
fn prepare_output(
unsafe fn prepare_output(
Contributor: Is it possible to narrow down the area of `unsafe`?

Member (Author): Okay

env: &mut JNIEnv,
array_addrs: jlongArray,
schema_addrs: jlongArray,
output_batch: RecordBatch,
exec_context: &mut ExecutionContext,
) -> CometResult<jlongArray> {
) -> CometResult<jlong> {
let array_address_array = JLongArray::from_raw(array_addrs);
let num_cols = env.get_array_length(&array_address_array)? as usize;

let array_addrs = env.get_array_elements(&array_address_array, ReleaseMode::NoCopyBack)?;
let array_addrs = &*array_addrs;

let schema_address_array = JLongArray::from_raw(schema_addrs);
let schema_addrs = env.get_array_elements(&schema_address_array, ReleaseMode::NoCopyBack)?;
let schema_addrs = &*schema_addrs;
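// Both address arrays are only read on the native side, so `NoCopyBack`
// releases them without copying the (unmodified) contents back to the JVM.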

let results = output_batch.columns();
let num_rows = output_batch.num_rows();

if results.len() != num_cols {
return Err(CometError::Internal(format!(
"Output column count mismatch: expected {num_cols}, got {}",
results.len()
)));
}

if exec_context.debug_native {
// Validate the output arrays.
for array in results.iter() {
@@ -283,35 +296,20 @@ fn prepare_output(
}
}

let return_flag = 1;

let long_array = env.new_long_array((results.len() * 2) as i32 + 2)?;
env.set_long_array_region(&long_array, 0, &[return_flag, num_rows as jlong])?;

let mut arrays = vec![];

let mut i = 0;
while i < results.len() {
let array_ref = results.get(i).ok_or(CometError::IndexOutOfBounds(i))?;
let (array, schema) = array_ref.to_data().to_spark()?;

unsafe {
let arrow_array = Arc::from_raw(array as *const FFI_ArrowArray);
let arrow_schema = Arc::from_raw(schema as *const FFI_ArrowSchema);
arrays.push((arrow_array, arrow_schema));
}
array_ref
.to_data()
.move_to_spark(array_addrs[i], schema_addrs[i])?;

env.set_long_array_region(&long_array, (i * 2) as i32 + 2, &[array, schema])?;
i += 1;
}

// Update metrics
update_metrics(env, exec_context)?;

// Record the pointer to allocated Arrow Arrays
exec_context.ffi_arrays = arrays;

Ok(long_array.into_raw())
Ok(num_rows as jlong)
}

/// Pull the next input from JVM. Note that we cannot pull input batches in
@@ -337,7 +335,9 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
e: JNIEnv,
_class: JClass,
exec_context: jlong,
) -> jlongArray {
array_addrs: jlongArray,
schema_addrs: jlongArray,
) -> jlong {
try_unwrap_or_throw(&e, |mut env| {
// Retrieve the query
let exec_context = get_execution_context(exec_context);
@@ -383,7 +383,13 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(

match poll_output {
Poll::Ready(Some(output)) => {
return prepare_output(&mut env, output?, exec_context);
return prepare_output(
&mut env,
array_addrs,
schema_addrs,
output?,
exec_context,
);
}
Poll::Ready(None) => {
// Reaches EOF of output.
@@ -399,10 +405,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
}
}

let long_array = env.new_long_array(1)?;
env.set_long_array_region(&long_array, 0, &[-1])?;

return Ok(long_array.into_raw());
return Ok(-1);
}
// A poll pending means there are more than one blocking operators,
// we don't need go back-forth between JVM/Native. Just keeping polling.
24 changes: 24 additions & 0 deletions native/core/src/execution/utils.rs
@@ -55,6 +55,9 @@ pub trait SparkArrowConvert {
/// Convert Arrow Arrays to C data interface.
/// It returns a tuple (ArrowArray address, ArrowSchema address).
fn to_spark(&self) -> Result<(i64, i64), ExecutionError>;

/// Move Arrow Arrays to C data interface.
fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError>;
}

impl SparkArrowConvert for ArrayData {
@@ -96,6 +99,27 @@ impl SparkArrowConvert for ArrayData {

Ok((array as i64, schema as i64))
}

/// Moves this ArrayData to the given pointers of the Arrow C data interface.
fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> {
let array_ptr = array as *mut FFI_ArrowArray;
let schema_ptr = schema as *mut FFI_ArrowSchema;

let array_align = std::mem::align_of::<FFI_ArrowArray>();
let schema_align = std::mem::align_of::<FFI_ArrowSchema>();

// Check if the pointer alignment is correct for `replace`.
if array_ptr.align_offset(array_align) != 0 || schema_ptr.align_offset(schema_align) != 0 {
return Err(ExecutionError::ArrowError(
"Pointer alignment is not correct".to_string(),
));
}
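
// `std::ptr::replace` moves the freshly built FFI struct into the
// JVM-allocated slot and returns the previous value, which is dropped here.
// The JVM pre-marks these structs as released (for ArrowSchema by writing NULL
// into its `release` slot in `NativeUtil.allocateArrowStructs`), so the drop
// does not invoke a stale release callback.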

unsafe { std::ptr::replace(array_ptr, FFI_ArrowArray::new(self)) };
unsafe { std::ptr::replace(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?) };

Ok(())
}
}

/// Converts a slice of bytes to i128. The bytes are serialized in big-endian order by
22 changes: 6 additions & 16 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -43,6 +43,7 @@ import org.apache.comet.vector.NativeUtil
class CometExecIterator(
val id: Long,
inputs: Seq[Iterator[ColumnarBatch]],
numOutputCols: Int,
protobufQueryPlan: Array[Byte],
nativeMetrics: CometMetricNode)
extends Iterator[ColumnarBatch] {
@@ -100,22 +101,11 @@
}

def getNextBatch(): Option[ColumnarBatch] = {
// we execute the native plan each time we need another output batch and this could
// result in multiple input batches being processed
val result = nativeLib.executePlan(plan)

result(0) match {
case -1 =>
// EOF
None
case 1 =>
val numRows = result(1)
val addresses = result.slice(2, result.length)
val cometVectors = nativeUtil.importVector(addresses)
Some(new ColumnarBatch(cometVectors.toArray, numRows.toInt))
case flag =>
throw new IllegalStateException(s"Invalid native flag: $flag")
}
nativeUtil.getNextBatch(
numOutputCols,
(arrayAddrs, schemaAddrs) => {
nativeLib.executePlan(plan, arrayAddrs, schemaAddrs)
})
}

override def hasNext: Boolean = {
10 changes: 6 additions & 4 deletions spark/src/main/scala/org/apache/comet/Native.scala
@@ -58,12 +58,14 @@ class Native extends NativeBase {
*
* @param plan
* the address of the native query plan.
* @param arrayAddrs
* the addresses of Arrow Array structures
* @param schemaAddrs
* the addresses of Arrow Schema structures
* @return
* an array containing: 1) the status flag (1 for normal returned arrays, -1 for end of
* output) 2) (optional) the number of rows if returned flag is 1 3) the addresses of output
* Arrow arrays
* the number of rows, or -1 if the end of the output has been reached.
*/
@native def executePlan(plan: Long): Array[Long]
@native def executePlan(plan: Long, arrayAddrs: Array[Long], schemaAddrs: Array[Long]): Long

/**
* Release and drop the native query plan object and context object.
@@ -53,7 +53,7 @@ object CometExecUtils {
limit: Int): RDD[ColumnarBatch] = {
childPlan.mapPartitionsInternal { iter =>
val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
CometExec.getCometIterator(Seq(iter), limitOp)
CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp)
}
}

@@ -82,7 +82,7 @@ case class CometTakeOrderedAndProjectExec(
CometExecUtils
.getTopKNativePlan(child.output, sortOrder, child, limit)
.get
CometExec.getCometIterator(Seq(iter), topK)
CometExec.getCometIterator(Seq(iter), child.output.length, topK)
}
}

@@ -102,7 +102,7 @@
val topKAndProjection = CometExecUtils
.getProjectionNativePlan(projectList, child.output, sortOrder, child, limit)
.get
val it = CometExec.getCometIterator(Seq(iter), topKAndProjection)
val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection)
setSubqueries(it.id, this)

Option(TaskContext.get()).foreach { context =>
@@ -487,6 +487,7 @@ class CometShuffleWriteProcessor(

val cometIter = CometExec.getCometIterator(
Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
outputAttributes.length,
nativePlan,
nativeMetrics)
