From 77421727662ad164fdb6312fc3f5d0cc2b006be3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 Aug 2025 11:54:59 -0600 Subject: [PATCH 01/13] deep copy all scans, remove Comet filter --- native/core/src/execution/operators/copy.rs | 2 +- native/core/src/execution/operators/filter.rs | 574 ------------------ native/core/src/execution/operators/mod.rs | 3 - native/core/src/execution/operators/scan.rs | 11 +- native/core/src/execution/planner.rs | 78 +-- native/proto/src/proto/operator.proto | 1 - .../apache/comet/serde/QueryPlanSerde.scala | 17 - 7 files changed, 27 insertions(+), 659 deletions(-) delete mode 100644 native/core/src/execution/operators/filter.rs diff --git a/native/core/src/execution/operators/copy.rs b/native/core/src/execution/operators/copy.rs index f1e87c2e05..950be36b67 100644 --- a/native/core/src/execution/operators/copy.rs +++ b/native/core/src/execution/operators/copy.rs @@ -245,7 +245,7 @@ impl RecordBatchStream for CopyStream { } /// Copy an Arrow Array -fn copy_array(array: &dyn Array) -> ArrayRef { +pub(crate) fn copy_array(array: &dyn Array) -> ArrayRef { let capacity = array.len(); let data = array.to_data(); diff --git a/native/core/src/execution/operators/filter.rs b/native/core/src/execution/operators/filter.rs deleted file mode 100644 index 272bbd4d89..0000000000 --- a/native/core/src/execution/operators/filter.rs +++ /dev/null @@ -1,574 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::any::Any; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{ready, Context, Poll}; - -use datafusion::physical_plan::{ - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, - ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, -}; - -use arrow::array::{ - make_array, Array, ArrayRef, BooleanArray, MutableArrayData, RecordBatchOptions, -}; -use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, SchemaRef}; -use arrow::error::ArrowError; -use arrow::record_batch::RecordBatch; -use datafusion::common::cast::as_boolean_array; -use datafusion::common::stats::Precision; -use datafusion::common::{ - internal_err, plan_err, project_schema, DataFusionError, Result, ScalarValue, -}; -use datafusion::execution::TaskContext; -use datafusion::logical_expr::Operator; -use datafusion::physical_expr::equivalence::ProjectionMapping; -use datafusion::physical_expr::expressions::BinaryExpr; -use datafusion::physical_expr::intervals::utils::check_support; -use datafusion::physical_expr::utils::collect_columns; -use datafusion::physical_expr::{ - analyze, split_conjunction, AcrossPartitions, AnalysisContext, ConstExpr, ExprBoundaries, - PhysicalExpr, -}; -use datafusion::physical_plan::common::can_project; -use datafusion::physical_plan::execution_plan::CardinalityEffect; -use futures::stream::{Stream, StreamExt}; -use log::trace; - -/// This is a copy of DataFusion's FilterExec with one modification to 
ensure that input -/// batches are never passed through unchanged. The changes are between the comments -/// `BEGIN Comet change` and `END Comet change`. -#[derive(Debug, Clone)] -pub struct FilterExec { - /// The expression to filter on. This expression must evaluate to a boolean value. - predicate: Arc, - /// The input plan - input: Arc, - /// Execution metrics - metrics: ExecutionPlanMetricsSet, - /// Selectivity for statistics. 0 = no rows, 100 = all rows - default_selectivity: u8, - /// Properties equivalence properties, partitioning, etc. - cache: PlanProperties, - /// The projection indices of the columns in the output schema of join - projection: Option>, -} - -impl FilterExec { - /// Create a FilterExec on an input - pub fn try_new( - predicate: Arc, - input: Arc, - ) -> Result { - match predicate.data_type(input.schema().as_ref())? { - DataType::Boolean => { - let default_selectivity = 20; - let cache = - Self::compute_properties(&input, &predicate, default_selectivity, None)?; - Ok(Self { - predicate, - input: Arc::clone(&input), - metrics: ExecutionPlanMetricsSet::new(), - default_selectivity, - cache, - projection: None, - }) - } - other => { - plan_err!("Filter predicate must return BOOLEAN values, got {other:?}") - } - } - } - - pub fn with_default_selectivity( - mut self, - default_selectivity: u8, - ) -> Result { - if default_selectivity > 100 { - return plan_err!( - "Default filter selectivity value needs to be less than or equal to 100" - ); - } - self.default_selectivity = default_selectivity; - Ok(self) - } - - /// Return new instance of [FilterExec] with the given projection. 
- pub fn with_projection(&self, projection: Option>) -> Result { - // Check if the projection is valid - can_project(&self.schema(), projection.as_ref())?; - - let projection = match projection { - Some(projection) => match &self.projection { - Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), - None => Some(projection), - }, - None => None, - }; - - let cache = Self::compute_properties( - &self.input, - &self.predicate, - self.default_selectivity, - projection.as_ref(), - )?; - Ok(Self { - predicate: Arc::clone(&self.predicate), - input: Arc::clone(&self.input), - metrics: self.metrics.clone(), - default_selectivity: self.default_selectivity, - cache, - projection, - }) - } - - /// The expression to filter on. This expression must evaluate to a boolean value. - pub fn predicate(&self) -> &Arc { - &self.predicate - } - - /// The input plan - pub fn input(&self) -> &Arc { - &self.input - } - - /// The default selectivity - pub fn default_selectivity(&self) -> u8 { - self.default_selectivity - } - - /// Projection - pub fn projection(&self) -> Option<&Vec> { - self.projection.as_ref() - } - - /// Calculates `Statistics` for `FilterExec`, by applying selectivity (either default, or estimated) to input statistics. 
- fn statistics_helper( - input: &Arc, - predicate: &Arc, - default_selectivity: u8, - ) -> Result { - let input_stats = input.partition_statistics(None)?; - let schema = input.schema(); - if !check_support(predicate, &schema) { - let selectivity = default_selectivity as f64 / 100.0; - let mut stats = input_stats.to_inexact(); - stats.num_rows = stats.num_rows.with_estimated_selectivity(selectivity); - stats.total_byte_size = stats - .total_byte_size - .with_estimated_selectivity(selectivity); - return Ok(stats); - } - - let num_rows = input_stats.num_rows; - let total_byte_size = input_stats.total_byte_size; - let input_analysis_ctx = - AnalysisContext::try_from_statistics(&input.schema(), &input_stats.column_statistics)?; - - let analysis_ctx = analyze(predicate, input_analysis_ctx, &schema)?; - - // Estimate (inexact) selectivity of predicate - let selectivity = analysis_ctx.selectivity.unwrap_or(1.0); - let num_rows = num_rows.with_estimated_selectivity(selectivity); - let total_byte_size = total_byte_size.with_estimated_selectivity(selectivity); - - let column_statistics = - collect_new_statistics(&input_stats.column_statistics, analysis_ctx.boundaries); - Ok(Statistics { - num_rows, - total_byte_size, - column_statistics, - }) - } - - fn extend_constants( - input: &Arc, - predicate: &Arc, - ) -> Vec { - let mut res_constants = Vec::new(); - let input_eqs = input.equivalence_properties(); - - let conjunctions = split_conjunction(predicate); - for conjunction in conjunctions { - if let Some(binary) = conjunction.as_any().downcast_ref::() { - if binary.op() == &Operator::Eq { - // Filter evaluates to single value for all partitions - if input_eqs.is_expr_constant(binary.left()).is_some() { - let across = input_eqs - .is_expr_constant(binary.right()) - .unwrap_or_default(); - res_constants.push(ConstExpr::new(Arc::clone(binary.right()), across)); - } else if input_eqs.is_expr_constant(binary.right()).is_some() { - let across = input_eqs - 
.is_expr_constant(binary.left()) - .unwrap_or_default(); - res_constants.push(ConstExpr::new(Arc::clone(binary.left()), across)); - } - } - } - } - res_constants - } - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - input: &Arc, - predicate: &Arc, - default_selectivity: u8, - projection: Option<&Vec>, - ) -> Result { - // Combine the equal predicates with the input equivalence properties - // to construct the equivalence properties: - let stats = Self::statistics_helper(input, predicate, default_selectivity)?; - let mut eq_properties = input.equivalence_properties().clone(); - let (equal_pairs, _) = collect_columns_from_predicate(predicate); - for (lhs, rhs) in equal_pairs { - eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))? - } - // Add the columns that have only one viable value (singleton) after - // filtering to constants. - let constants = collect_columns(predicate) - .into_iter() - .filter(|column| stats.column_statistics[column.index()].is_singleton()) - .map(|column| { - let value = stats.column_statistics[column.index()] - .min_value - .get_value(); - let expr = Arc::new(column) as _; - ConstExpr::new(expr, AcrossPartitions::Uniform(value.cloned())) - }); - // This is for statistics - eq_properties.add_constants(constants)?; - // This is for logical constant (for example: a = '1', then a could be marked as a constant) - // to do: how to deal with a multiple situation to represent = (for example, c1 between 0 and 0) - eq_properties.add_constants(Self::extend_constants(input, predicate))?; - - let mut output_partitioning = input.output_partitioning().clone(); - // If contains projection, update the PlanProperties. 
- if let Some(projection) = projection { - let schema = eq_properties.schema(); - let projection_mapping = ProjectionMapping::from_indices(projection, schema)?; - let out_schema = project_schema(schema, Some(projection))?; - output_partitioning = output_partitioning.project(&projection_mapping, &eq_properties); - eq_properties = eq_properties.project(&projection_mapping, out_schema); - } - - Ok(PlanProperties::new( - eq_properties, - output_partitioning, - input.pipeline_behavior(), - input.boundedness(), - )) - } -} - -impl DisplayAs for FilterExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - let display_projections = if let Some(projection) = self.projection.as_ref() { - format!( - ", projection=[{}]", - projection - .iter() - .map(|index| format!( - "{}@{}", - self.input.schema().fields().get(*index).unwrap().name(), - index - )) - .collect::>() - .join(", ") - ) - } else { - "".to_string() - }; - write!( - f, - "CometFilterExec: {}{}", - self.predicate, display_projections - ) - } - DisplayFormatType::TreeRender => unimplemented!(), - } - } -} - -impl ExecutionPlan for FilterExec { - fn name(&self) -> &'static str { - "CometFilterExec" - } - - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - &self.cache - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - - fn maintains_input_order(&self) -> Vec { - // Tell optimizer this operator doesn't reorder its input - vec![true] - } - - fn with_new_children( - self: Arc, - mut children: Vec>, - ) -> Result> { - FilterExec::try_new(Arc::clone(&self.predicate), children.swap_remove(0)) - .and_then(|e| { - let selectivity = e.default_selectivity(); - e.with_default_selectivity(selectivity) - }) - .and_then(|e| e.with_projection(self.projection().cloned())) - .map(|e| Arc::new(e) as 
_) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - trace!( - "Start FilterExec::execute for partition {} of context session_id {} and task_id {:?}", - partition, - context.session_id(), - context.task_id() - ); - let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - Ok(Box::pin(FilterExecStream { - schema: self.schema(), - predicate: Arc::clone(&self.predicate), - input: self.input.execute(partition, context)?, - baseline_metrics, - projection: self.projection.clone(), - })) - } - - fn metrics(&self) -> Option { - Some(self.metrics.clone_inner()) - } - - /// The output statistics of a filtering operation can be estimated if the - /// predicate's selectivity value can be determined for the incoming data. - fn statistics(&self) -> Result { - let stats = - Self::statistics_helper(&self.input, self.predicate(), self.default_selectivity)?; - Ok(stats.project(self.projection.as_ref())) - } - - fn cardinality_effect(&self) -> CardinalityEffect { - CardinalityEffect::LowerEqual - } -} - -/// This function ensures that all bounds in the `ExprBoundaries` vector are -/// converted to closed bounds. If a lower/upper bound is initially open, it -/// is adjusted by using the next/previous value for its data type to convert -/// it into a closed bound. -fn collect_new_statistics( - input_column_stats: &[ColumnStatistics], - analysis_boundaries: Vec, -) -> Vec { - analysis_boundaries - .into_iter() - .enumerate() - .map( - |( - idx, - ExprBoundaries { - interval, - distinct_count, - .. 
- }, - )| { - let Some(interval) = interval else { - // If the interval is `None`, we can say that there are no rows: - return ColumnStatistics { - null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Null), - min_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), - distinct_count: Precision::Exact(0), - }; - }; - let (lower, upper) = interval.into_bounds(); - let (min_value, max_value) = if lower.eq(&upper) { - (Precision::Exact(lower), Precision::Exact(upper)) - } else { - (Precision::Inexact(lower), Precision::Inexact(upper)) - }; - ColumnStatistics { - null_count: input_column_stats[idx].null_count.to_inexact(), - max_value, - min_value, - sum_value: Precision::Absent, - distinct_count: distinct_count.to_inexact(), - } - }, - ) - .collect() -} - -/// The FilterExec streams wraps the input iterator and applies the predicate expression to -/// determine which rows to include in its output batches -struct FilterExecStream { - /// Output schema after the projection - schema: SchemaRef, - /// The expression to filter on. This expression must evaluate to a boolean value. - predicate: Arc, - /// The input partition to filter. 
- input: SendableRecordBatchStream, - /// Runtime metrics recording - baseline_metrics: BaselineMetrics, - /// The projection indices of the columns in the input schema - projection: Option>, -} - -fn filter_and_project( - batch: &RecordBatch, - predicate: &Arc, - projection: Option<&Vec>, - output_schema: &SchemaRef, -) -> Result { - predicate - .evaluate(batch) - .and_then(|v| v.into_array(batch.num_rows())) - .and_then(|array| { - Ok(match (as_boolean_array(&array), projection) { - // Apply filter array to record batch - (Ok(filter_array), None) => comet_filter_record_batch(batch, filter_array)?, - (Ok(filter_array), Some(projection)) => { - let projected_columns = projection - .iter() - .map(|i| Arc::clone(batch.column(*i))) - .collect(); - let projected_batch = - RecordBatch::try_new(Arc::clone(output_schema), projected_columns)?; - comet_filter_record_batch(&projected_batch, filter_array)? - } - (Err(_), _) => { - return internal_err!("Cannot create filter_array from non-boolean predicates"); - } - }) - }) -} - -// BEGIN Comet changes -// `FilterExec` could modify input batch or return input batch without change. Instead of always -// adding `CopyExec` on top of it, we only copy input batch for the special case. 
-pub fn comet_filter_record_batch( - record_batch: &RecordBatch, - predicate: &BooleanArray, -) -> std::result::Result { - if predicate.true_count() == record_batch.num_rows() { - // special case where we just make an exact copy - let arrays: Vec = record_batch - .columns() - .iter() - .map(|array| { - let capacity = array.len(); - let data = array.to_data(); - let mut mutable = MutableArrayData::new(vec![&data], false, capacity); - mutable.extend(0, 0, capacity); - make_array(mutable.freeze()) - }) - .collect(); - let options = RecordBatchOptions::new().with_row_count(Some(record_batch.num_rows())); - RecordBatch::try_new_with_options(Arc::clone(&record_batch.schema()), arrays, &options) - } else { - filter_record_batch(record_batch, predicate) - } -} -// END Comet changes - -impl Stream for FilterExecStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let poll; - loop { - match ready!(self.input.poll_next_unpin(cx)) { - Some(Ok(batch)) => { - let timer = self.baseline_metrics.elapsed_compute().timer(); - let filtered_batch = filter_and_project( - &batch, - &self.predicate, - self.projection.as_ref(), - &self.schema, - )?; - timer.done(); - // Skip entirely filtered batches - if filtered_batch.num_rows() == 0 { - continue; - } - poll = Poll::Ready(Some(Ok(filtered_batch))); - break; - } - value => { - poll = Poll::Ready(value); - break; - } - } - } - self.baseline_metrics.record_poll(poll) - } - - fn size_hint(&self) -> (usize, Option) { - // Same number of record batches - self.input.size_hint() - } -} - -impl RecordBatchStream for FilterExecStream { - fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) - } -} - -/// Return the equals Column-Pairs and Non-equals Column-Pairs -fn collect_columns_from_predicate(predicate: &Arc) -> EqualAndNonEqual<'_> { - let mut eq_predicate_columns = Vec::::new(); - let mut ne_predicate_columns = Vec::::new(); - - let predicates = split_conjunction(predicate); - 
predicates.into_iter().for_each(|p| { - if let Some(binary) = p.as_any().downcast_ref::() { - match binary.op() { - Operator::Eq => eq_predicate_columns.push((binary.left(), binary.right())), - Operator::NotEq => ne_predicate_columns.push((binary.left(), binary.right())), - _ => {} - } - } - }); - - (eq_predicate_columns, ne_predicate_columns) -} - -/// Pair of `Arc`s -pub type PhysicalExprPairRef<'a> = (&'a Arc, &'a Arc); - -/// The equals Column-Pairs and Non-equals Column-Pairs in the Predicates -pub type EqualAndNonEqual<'a> = (Vec>, Vec>); diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 4e15e4341a..c8cfebd45e 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -22,14 +22,11 @@ use std::fmt::Debug; use jni::objects::GlobalRef; pub use copy::*; -pub use filter::comet_filter_record_batch; -pub use filter::FilterExec; pub use scan::*; mod copy; mod expand; pub use expand::ExpandExec; -mod filter; mod scan; /// Error returned during executing operators. diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index a842efaa30..6d252743bd 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::execution::operators::copy_array; use crate::{ errors::CometError, execution::{ @@ -276,8 +277,16 @@ impl ScanExec { let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?; // TODO: validate array input data + // array_data.validate_full()?; - inputs.push(make_array(array_data)); + let array = make_array(array_data); + + // we copy the array to that we don't have to worry about potential memory + // corruption issues later on if underlying buffers are reused or freed + // TODO optimize this so that we only do this for Parquet inputs! 
+ let array = copy_array(&array); + + inputs.push(array); // Drop the Arcs to avoid memory leak unsafe { diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 8533dae21a..27aa2ffb85 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -18,7 +18,6 @@ //! Converts Spark physical plan to DataFusion physical plan use crate::execution::operators::CopyMode; -use crate::execution::operators::FilterExec as CometFilterExec; use crate::{ errors::ExpressionError, execution::{ @@ -93,7 +92,7 @@ use arrow::array::{ use arrow::buffer::BooleanBuffer; use datafusion::common::utils::SingleRowListArrayBuilder; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec; +use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::limit::GlobalLimitExec; use datafusion_comet_proto::spark_operator::SparkFilePartition; use datafusion_comet_proto::{ @@ -1178,25 +1177,17 @@ impl PhysicalPlanner { let predicate = self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?; - let filter: Arc = - match (filter.wrap_child_in_copy_exec, filter.use_datafusion_filter) { - (true, true) => Arc::new(DataFusionFilterExec::try_new( - predicate, - Self::wrap_in_copy_exec(Arc::clone(&child.native_plan)), - )?), - (true, false) => Arc::new(CometFilterExec::try_new( - predicate, - Self::wrap_in_copy_exec(Arc::clone(&child.native_plan)), - )?), - (false, true) => Arc::new(DataFusionFilterExec::try_new( - predicate, - Arc::clone(&child.native_plan), - )?), - (false, false) => Arc::new(CometFilterExec::try_new( - predicate, - Arc::clone(&child.native_plan), - )?), - }; + let filter: Arc = if filter.wrap_child_in_copy_exec { + Arc::new(FilterExec::try_new( + predicate, + Self::wrap_in_copy_exec(Arc::clone(&child.native_plan)), + )?) 
+ } else { + Arc::new(FilterExec::try_new( + predicate, + Arc::clone(&child.native_plan), + )?) + }; Ok(( scans, @@ -1564,14 +1555,7 @@ impl PhysicalPlanner { // the data corruption. Note that we only need to copy the input batch // if the child operator is `ScanExec`, because other operators after `ScanExec` // will create new arrays for the output batch. - let input = if can_reuse_input_batch(&child.native_plan) { - Arc::new(CopyExec::new( - Arc::clone(&child.native_plan), - CopyMode::UnpackOrDeepCopy, - )) - } else { - Arc::clone(&child.native_plan) - }; + let input = Arc::clone(&child.native_plan); let expand = Arc::new(ExpandExec::new(projections, input, schema)); Ok(( scans, @@ -1898,14 +1882,9 @@ impl PhysicalPlanner { )) } - /// Wrap an ExecutionPlan in a CopyExec, which will unpack any dictionary-encoded arrays - /// and make a deep copy of other arrays if the plan re-uses batches. + /// Wrap an ExecutionPlan in a CopyExec, which will unpack any dictionary-encoded arrays. fn wrap_in_copy_exec(plan: Arc) -> Arc { - if can_reuse_input_batch(&plan) { - Arc::new(CopyExec::new(plan, CopyMode::UnpackOrDeepCopy)) - } else { - Arc::new(CopyExec::new(plan, CopyMode::UnpackOrClone)) - } + Arc::new(CopyExec::new(plan, CopyMode::UnpackOrClone)) } /// Create a DataFusion physical aggregate expression from Spark physical aggregate expression @@ -2606,30 +2585,6 @@ impl From for DataFusionError { } } -/// Returns true if given operator can return input array as output array without -/// modification. This is used to determine if we need to copy the input batch to avoid -/// data corruption from reusing the input batch. 
-fn can_reuse_input_batch(op: &Arc) -> bool { - if op.as_any().is::() { - // JVM side can return arrow buffers to the pool - // Also, native_comet scan reuses mutable buffers - true - } else if op.as_any().is::() { - let copy_exec = op.as_any().downcast_ref::().unwrap(); - copy_exec.mode() == &CopyMode::UnpackOrClone && can_reuse_input_batch(copy_exec.input()) - } else if op.as_any().is::() { - // CometFilterExec guarantees that all arrays have been copied - false - } else { - for child in op.children() { - if can_reuse_input_batch(child) { - return true; - } - } - false - } -} - /// Collects the indices of the columns in the input schema that are used in the expression /// and returns them as a pair of vectors, one for the left side and one for the right side. fn expr_to_columns( @@ -3104,7 +3059,6 @@ mod tests { children: vec![child_op], op_struct: Some(OpStruct::Filter(spark_operator::Filter { predicate: Some(expr), - use_datafusion_filter: false, wrap_child_in_copy_exec: false, })), } @@ -3118,7 +3072,7 @@ mod tests { let (_scans, filter_exec) = planner.create_plan(&op, &mut vec![], 1).unwrap(); - assert_eq!("CometFilterExec", filter_exec.native_plan.name()); + assert_eq!("FilterExec", filter_exec.native_plan.name()); assert_eq!(1, filter_exec.children.len()); assert_eq!(0, filter_exec.additional_native_plans.len()); } diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index ac34def810..5cb332ef03 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -109,7 +109,6 @@ message Projection { message Filter { spark.spark_expression.Expr predicate = 1; - bool use_datafusion_filter = 2; // Some expressions don't support dictionary arrays, so may need to wrap the child in a CopyExec bool wrap_child_in_copy_exec = 3; } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 
f0b4059b99..c646ceb690 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1788,22 +1788,6 @@ object QueryPlanSerde extends Logging with CometExprShim { val cond = exprToProto(condition, child.output) if (cond.isDefined && childOp.nonEmpty) { - // We need to determine whether to use DataFusion's FilterExec or Comet's - // FilterExec. The difference is that DataFusion's implementation will sometimes pass - // batches through whereas the Comet implementation guarantees that a copy is always - // made, which is critical when using `native_comet` scans due to buffer re-use - - // TODO this could be optimized more to stop walking the tree on hitting - // certain operators such as join or aggregate which will copy batches - def containsNativeCometScan(plan: SparkPlan): Boolean = { - plan match { - case w: CometScanWrapper => containsNativeCometScan(w.originalPlan) - case scan: CometScanExec => scan.scanImpl == CometConf.SCAN_NATIVE_COMET - case _: CometNativeScanExec => false - case _ => plan.children.exists(containsNativeCometScan) - } - } - // Some native expressions do not support operating on dictionary-encoded arrays, so // wrap the child in a CopyExec to unpack dictionaries first. 
def wrapChildInCopyExec(condition: Expression): Boolean = { @@ -1816,7 +1800,6 @@ object QueryPlanSerde extends Logging with CometExprShim { val filterBuilder = OperatorOuterClass.Filter .newBuilder() .setPredicate(cond.get) - .setUseDatafusionFilter(!containsNativeCometScan(op)) .setWrapChildInCopyExec(wrapChildInCopyExec(condition)) Some(builder.setFilter(filterBuilder).build()) } else { From 23eac5ea72505640e047637232d93ac16ca87505 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 Aug 2025 12:03:30 -0600 Subject: [PATCH 02/13] remove bench --- native/core/Cargo.toml | 4 -- native/core/benches/filter.rs | 112 ---------------------------------- 2 files changed, 116 deletions(-) delete mode 100644 native/core/benches/filter.rs diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 92b7da8e5d..5bd62a8903 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -119,7 +119,3 @@ harness = false [[bench]] name = "parquet_decode" harness = false - -[[bench]] -name = "filter" -harness = false diff --git a/native/core/benches/filter.rs b/native/core/benches/filter.rs deleted file mode 100644 index 82fa4aac66..0000000000 --- a/native/core/benches/filter.rs +++ /dev/null @@ -1,112 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License.use arrow::array::{ArrayRef, BooleanBuilder, Int32Builder, RecordBatch, StringBuilder}; - -use arrow::array::builder::{BooleanBuilder, Int32Builder, StringBuilder}; -use arrow::array::{ArrayRef, RecordBatch}; -use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, Field, Schema}; -use comet::execution::operators::comet_filter_record_batch; -use criterion::{criterion_group, criterion_main, Criterion}; -use std::hint::black_box; -use std::sync::Arc; -use std::time::Duration; - -fn criterion_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("filter"); - - let num_rows = 8192; - let num_int_cols = 4; - let num_string_cols = 4; - - let batch = create_record_batch(num_rows, num_int_cols, num_string_cols); - - // create some different predicates - let mut predicate_select_few = BooleanBuilder::with_capacity(num_rows); - let mut predicate_select_many = BooleanBuilder::with_capacity(num_rows); - let mut predicate_select_all = BooleanBuilder::with_capacity(num_rows); - for i in 0..num_rows { - predicate_select_few.append_value(i % 10 == 0); - predicate_select_many.append_value(i % 10 > 0); - predicate_select_all.append_value(true); - } - let predicate_select_few = predicate_select_few.finish(); - let predicate_select_many = predicate_select_many.finish(); - let predicate_select_all = predicate_select_all.finish(); - - // baseline uses Arrow's filter_record_batch method - group.bench_function("arrow_filter_record_batch - few rows selected", |b| { - b.iter(|| filter_record_batch(black_box(&batch), black_box(&predicate_select_few))) - }); - group.bench_function("arrow_filter_record_batch - many rows selected", |b| { - b.iter(|| filter_record_batch(black_box(&batch), black_box(&predicate_select_many))) - }); - group.bench_function("arrow_filter_record_batch - all rows selected", |b| { - b.iter(|| filter_record_batch(black_box(&batch), 
black_box(&predicate_select_all))) - }); - - group.bench_function("comet_filter_record_batch - few rows selected", |b| { - b.iter(|| comet_filter_record_batch(black_box(&batch), black_box(&predicate_select_few))) - }); - group.bench_function("comet_filter_record_batch - many rows selected", |b| { - b.iter(|| comet_filter_record_batch(black_box(&batch), black_box(&predicate_select_many))) - }); - group.bench_function("comet_filter_record_batch - all rows selected", |b| { - b.iter(|| comet_filter_record_batch(black_box(&batch), black_box(&predicate_select_all))) - }); - - group.finish(); -} - -fn create_record_batch(num_rows: usize, num_int_cols: i32, num_string_cols: i32) -> RecordBatch { - let mut int32_builder = Int32Builder::with_capacity(num_rows); - let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); - for i in 0..num_rows { - int32_builder.append_value(i as i32); - string_builder.append_value(format!("this is string #{i}")); - } - let int32_array = Arc::new(int32_builder.finish()); - let string_array = Arc::new(string_builder.finish()); - - let mut fields = vec![]; - let mut columns: Vec = vec![]; - let mut i = 0; - for _ in 0..num_int_cols { - fields.push(Field::new(format!("c{i}"), DataType::Int32, false)); - columns.push(int32_array.clone()); // note this is just copying a reference to the array - i += 1; - } - for _ in 0..num_string_cols { - fields.push(Field::new(format!("c{i}"), DataType::Utf8, false)); - columns.push(string_array.clone()); // note this is just copying a reference to the array - i += 1; - } - let schema = Schema::new(fields); - RecordBatch::try_new(Arc::new(schema), columns).unwrap() -} - -fn config() -> Criterion { - Criterion::default() - .measurement_time(Duration::from_millis(500)) - .warm_up_time(Duration::from_millis(500)) -} - -criterion_group! 
{ - name = benches; - config = config(); - targets = criterion_benchmark -} -criterion_main!(benches); From fe05a53e1923d87b3370f8f1ddf58d9b349aefa8 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 Aug 2025 12:05:53 -0600 Subject: [PATCH 03/13] apply sort merge join fix --- dev/benchmarks/comet-tpch.sh | 1 + dev/benchmarks/tpcbench.py | 3 +++ native/core/src/execution/planner.rs | 7 +++++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/dev/benchmarks/comet-tpch.sh b/dev/benchmarks/comet-tpch.sh index 9ab855f614..1634b80ad7 100755 --- a/dev/benchmarks/comet-tpch.sh +++ b/dev/benchmarks/comet-tpch.sh @@ -41,6 +41,7 @@ $SPARK_HOME/bin/spark-submit \ --conf spark.plugins=org.apache.spark.CometPlugin \ --conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \ --conf spark.comet.exec.replaceSortMergeJoin=true \ + --conf spark.comet.explain.native.enabled=true \ --conf spark.comet.cast.allowIncompatible=true \ tpcbench.py \ --name comet \ diff --git a/dev/benchmarks/tpcbench.py b/dev/benchmarks/tpcbench.py index 2a6b5708bc..18df6d68e1 100644 --- a/dev/benchmarks/tpcbench.py +++ b/dev/benchmarks/tpcbench.py @@ -62,6 +62,9 @@ def main(benchmark: str, data_path: str, query_path: str, iterations: int, outpu iter_start_time = time.time() for query in range(1, num_queries+1): + if query != 9: + continue + spark.sparkContext.setJobDescription(f"{benchmark} q{query}") # read text file diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 27aa2ffb85..71bb0c7486 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1587,9 +1587,12 @@ impl PhysicalPlanner { }) .collect(); + let left = Self::wrap_in_copy_exec(Arc::clone(&join_params.left.native_plan)); + let right = Self::wrap_in_copy_exec(Arc::clone(&join_params.right.native_plan)); + let join = Arc::new(SortMergeJoinExec::try_new( - Arc::clone(&join_params.left.native_plan), - 
Arc::clone(&join_params.right.native_plan), + Arc::clone(&left), + Arc::clone(&right), join_params.join_on, join_params.join_filter, join_params.join_type, From cb80055adc45342c92488155ae3ce0c72a224b34 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 Aug 2025 12:29:00 -0600 Subject: [PATCH 04/13] revert --- dev/benchmarks/comet-tpch.sh | 1 - dev/benchmarks/tpcbench.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/dev/benchmarks/comet-tpch.sh b/dev/benchmarks/comet-tpch.sh index 1634b80ad7..9ab855f614 100755 --- a/dev/benchmarks/comet-tpch.sh +++ b/dev/benchmarks/comet-tpch.sh @@ -41,7 +41,6 @@ $SPARK_HOME/bin/spark-submit \ --conf spark.plugins=org.apache.spark.CometPlugin \ --conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \ --conf spark.comet.exec.replaceSortMergeJoin=true \ - --conf spark.comet.explain.native.enabled=true \ --conf spark.comet.cast.allowIncompatible=true \ tpcbench.py \ --name comet \ diff --git a/dev/benchmarks/tpcbench.py b/dev/benchmarks/tpcbench.py index 18df6d68e1..2a6b5708bc 100644 --- a/dev/benchmarks/tpcbench.py +++ b/dev/benchmarks/tpcbench.py @@ -62,9 +62,6 @@ def main(benchmark: str, data_path: str, query_path: str, iterations: int, outpu iter_start_time = time.time() for query in range(1, num_queries+1): - if query != 9: - continue - spark.sparkContext.setJobDescription(f"{benchmark} q{query}") # read text file From 9cd118e3ef18a9957896d54cce74ba869f1a48ff Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 Aug 2025 13:08:36 -0600 Subject: [PATCH 05/13] optimize --- native/core/src/execution/operators/scan.rs | 25 +++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index 6d252743bd..a4910459dd 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -64,6 +64,8 @@ pub struct ScanExec { pub 
input_source: Option>, /// A description of the input source for informational purposes pub input_source_description: String, + /// Whether to copy incoming arrays + copy_arrays: bool, /// The data types of columns of the input batch. Converted from Spark schema. pub data_types: Vec, /// Schema of first batch @@ -95,6 +97,14 @@ impl ScanExec { let arrow_ffi_time = MetricBuilder::new(&metrics_set).subset_time("arrow_ffi_time", 0); let jvm_fetch_time = MetricBuilder::new(&metrics_set).subset_time("jvm_fetch_time", 0); + let copy_arrays = input_source_description.contains("native_comet") + || input_source_description.contains("native_iceberg_compat"); + + println!( + "*** {} -> copy_arrays = {copy_arrays}", + input_source_description + ); + // Scan's schema is determined by the input batch, so we need to set it before execution. // Note that we determine if arrays are dictionary-encoded based on the // first batch. The array may be dictionary-encoded in some batches and not others, and @@ -109,6 +119,7 @@ impl ScanExec { data_types.len(), &jvm_fetch_time, &arrow_ffi_time, + copy_arrays, )?; timer.stop(); batch @@ -131,6 +142,7 @@ impl ScanExec { exec_context_id, input_source, input_source_description: input_source_description.to_string(), + copy_arrays, data_types, batch: Arc::new(Mutex::new(Some(first_batch))), cache, @@ -187,6 +199,7 @@ impl ScanExec { self.data_types.len(), &self.jvm_fetch_time, &self.arrow_ffi_time, + self.copy_arrays, )?; *current_batch = Some(next_batch); } @@ -203,6 +216,7 @@ impl ScanExec { num_cols: usize, jvm_fetch_time: &Time, arrow_ffi_time: &Time, + make_copy: bool, ) -> Result { if exec_context_id == TEST_EXEC_CONTEXT_ID { // This is a unit test. We don't need to call JNI. 
@@ -281,10 +295,13 @@ impl ScanExec { let array = make_array(array_data); - // we copy the array to that we don't have to worry about potential memory - // corruption issues later on if underlying buffers are reused or freed - // TODO optimize this so that we only do this for Parquet inputs! - let array = copy_array(&array); + let array = if make_copy { + // we copy the array so that we don't have to worry about potential memory + // corruption issues later on if underlying buffers are reused or freed + copy_array(&array) + } else { + array + }; inputs.push(array); From d4217426ebec4b11391b14a4f6b0a298abc13e6f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 Aug 2025 13:39:31 -0600 Subject: [PATCH 06/13] debug --- spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index cf11bdf590..dcedebd01b 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -89,6 +89,7 @@ abstract class CometTestBase conf.set(CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key, "true") conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g") conf.set(CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key, "true") + conf.set(CometConf.COMET_EXPLAIN_NATIVE_ENABLED.key, "true") conf } From fbbe200b48e9370f8422f1c617f665986b245833 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 Aug 2025 13:55:41 -0600 Subject: [PATCH 07/13] save --- native/core/src/execution/operators/scan.rs | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index a4910459dd..c230310ab3 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -97,13 +97,19 @@ impl 
ScanExec { let arrow_ffi_time = MetricBuilder::new(&metrics_set).subset_time("arrow_ffi_time", 0); let jvm_fetch_time = MetricBuilder::new(&metrics_set).subset_time("jvm_fetch_time", 0); - let copy_arrays = input_source_description.contains("native_comet") - || input_source_description.contains("native_iceberg_compat"); - - println!( - "*** {} -> copy_arrays = {copy_arrays}", - input_source_description - ); + // TODO needs a more robust approach than looking at a text field + let copy_arrays = match input_source_description { + source if source.contains("native_comet") => true, + source if source.contains("native_iceberg_compat") => true, + source if source.contains("BroadcastQueryStage") => false, + source if source.contains("ShuffleQueryStage") => false, + _ => { + // take cautious approach for anything else because it could be backed + // by a Parquet scan + println!("ScanExec default to copying for {input_source_description}"); + true + } + }; // Scan's schema is determined by the input batch, so we need to set it before execution. 
// Note that we determine if arrays are dictionary-encoded based on the From c08e09ce8b2472a3a3eba56afe6e116924d1c04a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 Aug 2025 16:30:48 -0600 Subject: [PATCH 08/13] test --- native/core/src/execution/operators/scan.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index c230310ab3..e9c16215af 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -101,14 +101,7 @@ impl ScanExec { let copy_arrays = match input_source_description { source if source.contains("native_comet") => true, source if source.contains("native_iceberg_compat") => true, - source if source.contains("BroadcastQueryStage") => false, - source if source.contains("ShuffleQueryStage") => false, - _ => { - // take cautious approach for anything else because it could be backed - // by a Parquet scan - println!("ScanExec default to copying for {input_source_description}"); - true - } + _ => false, }; // Scan's schema is determined by the input batch, so we need to set it before execution. 
From 018f73b80ccd70c0fd8026ffdec01e8cb513462d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 15 Aug 2025 08:13:27 -0600 Subject: [PATCH 09/13] test --- native/core/src/execution/operators/scan.rs | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index e9c16215af..f4fe6cb63b 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -99,8 +99,24 @@ impl ScanExec { // TODO needs a more robust approach than looking at a text field let copy_arrays = match input_source_description { - source if source.contains("native_comet") => true, - source if source.contains("native_iceberg_compat") => true, + source if source.contains("native_comet") => + // mutable Parquet buffers get reused between batches + { + true + } + source if source.contains("native_iceberg_compat") => + // mutable Parquet buffers get reused between batches + { + true + } + source if source.contains("ShuffleWriterInput") => + // TODO need to understand why we get memory corruption if + // we don't do a deep copy here + { + true + } + source if source.contains("BroadcastQueryStage") => false, + source if source.contains("ShuffleQueryStage") => false, _ => false, }; From 2474451dc9e979041b6a8424237f3edadd4173fe Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 15 Aug 2025 08:24:19 -0600 Subject: [PATCH 10/13] repro --- native/core/src/execution/operators/scan.rs | 2 +- .../apache/comet/exec/CometJoinSuite.scala | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index f4fe6cb63b..e6332900c7 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -113,7 +113,7 @@ impl ScanExec { // TODO need to understand why we get memory corruption if // we don't 
do a deep copy here { - true + false } source if source.contains("BroadcastQueryStage") => false, source if source.contains("ShuffleQueryStage") => false, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index d47b4e0c1a..5b3d187e19 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -131,6 +131,27 @@ class CometJoinSuite extends CometTestBase { } } + test("repro for memory corruption") { + withSQLConf( + "spark.sql.join.forceApplyShuffledHashJoin" -> "true", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") { + // Right join: build left + val df2 = + sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + + df2.explain(true) + + checkSparkAnswerAndOperator(df2) + } + } + } + } + test("HashJoin without join filter") { withSQLConf( "spark.sql.join.forceApplyShuffledHashJoin" -> "true", From c5feb0b4da981eeba9409475680c2450e7f76d8a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 15 Aug 2025 09:33:39 -0600 Subject: [PATCH 11/13] save --- .../org/apache/comet/vector/NativeUtil.scala | 10 +- docs/source/contributor-guide/arrow-ffi.md | 185 ++++++++++++++++++ native/core/src/execution/jni_api.rs | 44 +++-- native/core/src/execution/operators/scan.rs | 7 +- .../org/apache/comet/CometBatchIterator.java | 1 + .../main/scala/org/apache/comet/Native.scala | 3 +- .../shuffle/NativeBatchDecoderIterator.scala | 7 +- .../apache/comet/exec/CometJoinSuite.scala | 1 + 8 files changed, 243 insertions(+), 15 deletions(-) create mode 100644 
docs/source/contributor-guide/arrow-ffi.md diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index fba4e29e5e..2740cfe8fb 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -99,6 +99,9 @@ class NativeUtil { case a: CometVector => val valueVector = a.getValueVector + // scalastyle:off println + println("exporting array from jvm to native: " + valueVector) + numRows += valueVector.getValueCount val provider = if (valueVector.getField.getDictionary != null) { @@ -185,10 +188,15 @@ class NativeUtil { // Native execution should always have 'useDecimal128' set to true since it doesn't support // other cases. - arrayVectors += CometVector.getVector( + val cometVector = CometVector.getVector( importer.importVector(arrowArray, arrowSchema, dictionaryProvider), true, dictionaryProvider) + + // scalastyle:off + println(s"importVector $cometVector") + + arrayVectors += cometVector } arrayVectors.toSeq } diff --git a/docs/source/contributor-guide/arrow-ffi.md b/docs/source/contributor-guide/arrow-ffi.md new file mode 100644 index 0000000000..932965a666 --- /dev/null +++ b/docs/source/contributor-guide/arrow-ffi.md @@ -0,0 +1,185 @@ +# Arrow FFI Usage + +## Overview + +This document provides an overview of Comet's usage of Arrow FFI (Foreign Function Interface) to pass Arrow arrays +between the JVM (Java/Scala) and native (Rust) code. + +## Architecture Overview + +Arrow FFI is used extensively in Comet to transfer columnar data across language boundaries. + +- prepare_output is only place in native code to export native to jvm. 
used in executePlan and decodeShuffleBlock + + + +### Executing native plans and fetching results to JVM (CometExecIterator) + +### Exporting batches from JVM to native plan (CometBatchIterator) + +### Shuffle (decodeShuffleBlock reads from a Java DirectByteBuffer) + +- CometBlockStoreShuffleReader +- NativeBatchDecoderIterator +- native decodeShuffleBlock + +## Debugging FFI Issues + +TBD + +# [IGNORE] AI generated docs below - will incorporate into above documentation + +The main components +involved are: + +1. **JVM Side**: Scala/Java code that manages Arrow arrays and vectors +2. **Native Side**: Rust code that processes data using DataFusion +3. **Arrow C Data Interface**: Standard FFI structures (`FFI_ArrowArray` and `FFI_ArrowSchema`) + +## Key FFI Usage Patterns + +### 1. JVM to Native (Data Import) + +**Location**: `common/src/main/scala/org/apache/comet/vector/NativeUtil.scala` + +The JVM exports Arrow data to native code through the `exportBatch` method. This code is called via JNI from +native code. + +```scala +def exportBatch(arrayAddrs: Array[Long], schemaAddrs: Array[Long], batch: ColumnarBatch): Int = { + (0 until batch.numCols()).foreach { index => + val arrowSchema = ArrowSchema.wrap(schemaAddrs(index)) + val arrowArray = ArrowArray.wrap(arrayAddrs(index)) + Data.exportVector(allocator, getFieldVector(valueVector, "export"), provider, arrowArray, arrowSchema) + } +} +``` + +**Memory Management**: + +- `ArrowArray` and `ArrowSchema` structures are allocated by native side +- JVM uses `ArrowSchema.wrap()` and `ArrowArray.wrap()` to wrap existing pointers +- No deallocation needed on JVM side as structures are owned by native code + +### 2. 
Native to JVM (Data Export) + +**Location**: `native/core/src/execution/jni_api.rs` + +Native code exports data back to JVM through the `prepare_output` function: + +```rust +fn prepare_output(env: &mut JNIEnv, array_addrs: jlongArray, schema_addrs: jlongArray, + output_batch: RecordBatch, validate: bool) -> CometResult { + // Get memory addresses from JVM + let array_addrs = unsafe { env.get_array_elements(&array_address_array, ReleaseMode::NoCopyBack)? }; + let schema_addrs = unsafe { env.get_array_elements(&schema_address_array, ReleaseMode::NoCopyBack)? }; + + // Export each column + array_ref.to_data().move_to_spark(array_addrs[i], schema_addrs[i])?; +} +``` + +### 3. FFI Conversion Implementation + +**Location**: `native/core/src/execution/utils.rs` + +The core FFI conversion logic implements the `SparkArrowConvert` trait: + +```rust +impl SparkArrowConvert for ArrayData { + fn from_spark(addresses: (i64, i64)) -> Result { + let (array_ptr, schema_ptr) = addresses; + let array_ptr = array_ptr as *mut FFI_ArrowArray; + let schema_ptr = schema_ptr as *mut FFI_ArrowSchema; + + let mut ffi_array = unsafe { + let array_data = std::ptr::replace(array_ptr, FFI_ArrowArray::empty()); + let schema_data = std::ptr::replace(schema_ptr, FFI_ArrowSchema::empty()); + from_ffi(array_data, &schema_data)? + }; + + ffi_array.align_buffers(); + Ok(ffi_array) + } + + fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> { + let array_ptr = array as *mut FFI_ArrowArray; + let schema_ptr = schema as *mut FFI_ArrowSchema; + + unsafe { + std::ptr::write(array_ptr, FFI_ArrowArray::new(self)); + std::ptr::write(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?); + } + Ok(()) + } +} +``` + +## Memory Lifecycle Management + +### 1. 
Memory Allocation Strategy + +**ArrowArray and ArrowSchema Structures**: + +- **JVM Side**: Uses `ArrowArray.allocateNew(allocator)` and `ArrowSchema.allocateNew(allocator)` in `NativeUtil.allocateArrowStructs()` +- **Native Side**: Creates empty structures with `FFI_ArrowArray::empty()` and `FFI_ArrowSchema::empty()` + +**Buffer Memory**: + +- Arrow buffers are managed by Arrow's memory pool system +- Reference counting ensures proper cleanup +- Buffers can be shared between arrays through Arc<> wrappers + +### 2. Ownership Transfer Patterns + +**JVM to Native Transfer**: + +1. JVM allocates ArrowArray/ArrowSchema structures +2. JVM exports data using Arrow C Data Interface +3. Native code imports using `from_ffi()` which transfers ownership +4. Native code processes data and may modify it +5. JVM structures remain allocated until explicitly freed + +**Native to JVM Transfer**: + +1. Native code writes data to JVM-allocated structures using `move_to_spark()` +2. JVM wraps the structures and creates CometVectors +3. JVM takes ownership of the data buffers +4. Native structures can be safely dropped + +### 3. Memory Cleanup + +**JVM Cleanup**: + +- `NativeUtil.close()` closes dictionary provider which releases dictionary arrays +- Individual batches are closed via `ColumnarBatch.close()` in `CometExecIterator` +- Arrow allocator tracks memory usage but reports false positives due to FFI transfers + +**Native Cleanup**: + +- Rust's RAII automatically drops structures when they go out of scope +- `std::ptr::replace()` with empty structures ensures proper cleanup +- Explicit `Rc::from_raw()` calls in scan operations to avoid memory leaks + +## Memory Safety Risks + +**Location**: `spark/src/main/scala/org/apache/comet/CometExecIterator.scala` + +```scala +// Close previous batch if any. +// This is to guarantee safety at the native side before we overwrite the buffer memory +// shared across batches in the native side. 
+if (prevBatch != null) { + prevBatch.close() + prevBatch = null +} +``` + +**Risk**: The comment explicitly mentions "shared buffer memory across batches" but there's a window where: + +1. Native code might still have references to a batch +2. JVM closes the previous batch, potentially freeing buffers +3. Native code accesses freed memory + +**Mitigation**: In `planner.rs` we insert `CopyExec` operators to perform copies of arrays for operators +that may cache batches, but this is an area that we may improve in the future. diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 52ef184b19..6764b8493c 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -305,8 +305,16 @@ fn prepare_output( array_addrs: jlongArray, schema_addrs: jlongArray, output_batch: RecordBatch, - validate: bool, + stage_id: jint, + debug: bool, ) -> CometResult { + if debug { + println!( + "prepare_output stage_id={stage_id} (passing batch from native to JVM):\n{:?}", + output_batch + ); + } + let array_address_array = unsafe { JLongArray::from_raw(array_addrs) }; let num_cols = env.get_array_length(&array_address_array)? as usize; @@ -332,7 +340,7 @@ fn prepare_output( ))); } - if validate { + if debug { // Validate the output arrays. 
for array in results.iter() { let array_data = array.to_data(); @@ -359,13 +367,17 @@ fn prepare_output( assert_eq!(new_array.offset(), 0); - new_array - .to_data() - .move_to_spark(array_addrs[i], schema_addrs[i])?; + let data = new_array.to_data(); + if debug { + data.validate_full()?; + } + data.move_to_spark(array_addrs[i], schema_addrs[i])?; } else { - array_ref - .to_data() - .move_to_spark(array_addrs[i], schema_addrs[i])?; + let data = array_ref.to_data(); + if debug { + data.validate_full()?; + } + data.move_to_spark(array_addrs[i], schema_addrs[i])?; } i += 1; } @@ -447,7 +459,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( if exec_context.explain_native { let formatted_plan_str = DisplayableExecutionPlan::new(root_op.native_plan.as_ref()).indent(true); - info!("Comet native query plan:\n{formatted_plan_str:}"); + info!("Comet native query plan (Plan #{} Stage {} Partition {}):\n{formatted_plan_str:}", + exec_context.root_op.as_ref().unwrap().plan_id, stage_id, partition); } let task_ctx = exec_context.session_ctx.task_ctx(); @@ -485,6 +498,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( array_addrs, schema_addrs, output?, + stage_id, exec_context.debug_native, ); } @@ -528,6 +542,9 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan( _class: JClass, exec_context: jlong, ) { + // TODO this removes references to the last batch from the plan - is this safe? 
+ println!("releasePlan"); + try_unwrap_or_throw(&e, |mut env| unsafe { let execution_context = get_execution_context(exec_context); @@ -715,14 +732,19 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock( array_addrs: jlongArray, schema_addrs: jlongArray, tracing_enabled: jboolean, + debug: bool ) -> jlong { try_unwrap_or_throw(&e, |mut env| { with_trace("decodeShuffleBlock", tracing_enabled != JNI_FALSE, || { - let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; + let raw_pointer: *mut u8 = env.get_direct_buffer_address(&byte_buffer)?; let length = length as usize; let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; let batch = read_ipc_compressed(slice)?; - prepare_output(&mut env, array_addrs, schema_addrs, batch, false) + if debug { + println!("decode shuffle block from JVM DirectByteBuffer @ {:?}", raw_pointer); + println!("decoded batch: {batch:?}"); + } + prepare_output(&mut env, array_addrs, schema_addrs, batch, 0, debug) }) }) } diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index e6332900c7..95630e0057 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -306,7 +306,7 @@ impl ScanExec { let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?; // TODO: validate array input data - // array_data.validate_full()?; + array_data.validate_full()?; let array = make_array(array_data); @@ -513,6 +513,11 @@ impl Stream for ScanStream<'_> { InputBatch::Batch(columns, num_rows) => { self.baseline_metrics.record_output(*num_rows); let maybe_batch = self.build_record_batch(columns, *num_rows); + + if let Ok(batch) = &maybe_batch { + println!("native got batch from jvm: {:?}", batch); + } + Poll::Ready(Some(maybe_batch)) } }; diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java b/spark/src/main/java/org/apache/comet/CometBatchIterator.java index 
e05bea1dff..82a0e21fd3 100644 --- a/spark/src/main/java/org/apache/comet/CometBatchIterator.java +++ b/spark/src/main/java/org/apache/comet/CometBatchIterator.java @@ -69,6 +69,7 @@ public int next(long[] arrayAddrs, long[] schemaAddrs) { if (currentBatch == null) { return -1; } + int numRows = nativeUtil.exportBatch(arrayAddrs, schemaAddrs, currentBatch); currentBatch = null; return numRows; diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index afaa3d17db..d63e55707a 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -175,7 +175,8 @@ class Native extends NativeBase { length: Int, arrayAddrs: Array[Long], schemaAddrs: Array[Long], - tracingEnabled: Boolean): Long + tracingEnabled: Boolean, + debug: Boolean): Long /** * Log the beginning of an event. diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala index d461564f0a..0f29910b4e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala @@ -167,7 +167,8 @@ case class NativeBatchDecoderIterator( bytesToRead.toInt, arrayAddrs, schemaAddrs, - CometConf.COMET_TRACING_ENABLED.get()) + CometConf.COMET_TRACING_ENABLED.get(), + CometConf.COMET_DEBUG_ENABLED.get()) }) decodeTime.add(System.nanoTime() - startTime) @@ -175,6 +176,10 @@ case class NativeBatchDecoderIterator( } def close(): Unit = { + + // scalastyle:off println + println("Closing shuffle batch in JVM") + synchronized { if (!isClosed) { if (currentBatch != null) { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala 
index 5b3d187e19..7a2df9de12 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -134,6 +134,7 @@ class CometJoinSuite extends CometTestBase { test("repro for memory corruption") { withSQLConf( "spark.sql.join.forceApplyShuffledHashJoin" -> "true", + CometConf.COMET_DEBUG_ENABLED.key -> "true", SQLConf.PREFER_SORTMERGEJOIN.key -> "false", SQLConf.SHUFFLE_PARTITIONS.key -> "2", SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", From 9f57233875d7e9772e63f3859f480816023452f2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 15 Aug 2025 11:30:40 -0600 Subject: [PATCH 12/13] save --- .../org/apache/comet/vector/CometVector.java | 3 ++ native/Cargo.lock | 1 + native/core/Cargo.toml | 1 + native/core/src/execution/jni_api.rs | 17 +++++++++-- native/core/src/execution/operators/scan.rs | 5 +++- native/core/src/execution/shuffle/codec.rs | 11 ++++++++ native/core/src/execution/shuffle/row.rs | 2 ++ native/core/src/execution/utils.rs | 1 + .../org/apache/comet/CometBatchIterator.java | 14 +++++++++- .../sort/CometShuffleExternalSorter.java | 1 + .../shuffle/CometDiskBlockWriter.java | 4 +++ .../comet/execution/shuffle/SpillWriter.java | 1 + .../org/apache/comet/CometExecIterator.scala | 5 ++++ .../shuffle/NativeBatchDecoderIterator.scala | 28 +++++++++++++++---- 14 files changed, 83 insertions(+), 11 deletions(-) diff --git a/common/src/main/java/org/apache/comet/vector/CometVector.java b/common/src/main/java/org/apache/comet/vector/CometVector.java index 6be8b28669..e1368e2ae0 100644 --- a/common/src/main/java/org/apache/comet/vector/CometVector.java +++ b/common/src/main/java/org/apache/comet/vector/CometVector.java @@ -60,6 +60,7 @@ public abstract class CometVector extends ColumnVector { protected CometVector(DataType type, boolean useDecimal128) { super(type); + System.out.println("new CometVector"); this.useDecimal128 = useDecimal128; } @@ -207,6 +208,8 @@ 
public ColumnVector getChild(int i) { @Override public void close() { + System.out.println( + "[" + Thread.currentThread().getName() + "] CometVector.close() " + getValueVector()); getValueVector().close(); } diff --git a/native/Cargo.lock b/native/Cargo.lock index 4f4771b2d9..57466bf8d8 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1505,6 +1505,7 @@ dependencies = [ "itertools 0.14.0", "jni", "lazy_static", + "libc", "log", "log4rs", "lz4_flex", diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 5bd62a8903..138764a0ba 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -73,6 +73,7 @@ aws-config = { workspace = true } aws-credential-types = { workspace = true } parking_lot = "0.12.3" datafusion-comet-objectstore-hdfs = { path = "../hdfs", optional = true, default-features = false, features = ["hdfs"] } +libc = "0.2.175" [target.'cfg(target_os = "linux")'.dependencies] procfs = "0.17.0" diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 6764b8493c..0b5dc2c2ef 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -62,7 +62,7 @@ use jni::{ }; use std::path::PathBuf; use std::time::{Duration, Instant}; -use std::{sync::Arc, task::Poll}; +use std::{sync::Arc, task::Poll, thread}; use tokio::runtime::Runtime; use crate::execution::memory_pools::{ @@ -414,6 +414,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( array_addrs: jlongArray, schema_addrs: jlongArray, ) -> jlong { + use libc::pid_t; + let tid: pid_t = unsafe { libc::syscall(libc::SYS_gettid) as pid_t }; + println!("[{:?}] executePlan", tid); + try_unwrap_or_throw(&e, |mut env| { // Retrieve the query let exec_context = get_execution_context(exec_context); @@ -622,6 +626,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative compression_level: jint, tracing_enabled: jboolean, ) -> jlongArray { + println!("[{:?}] 
writeSortedFileNative", thread::current().id()); + try_unwrap_or_throw(&e, |mut env| unsafe { with_trace( "writeSortedFileNative", @@ -732,8 +738,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock( array_addrs: jlongArray, schema_addrs: jlongArray, tracing_enabled: jboolean, - debug: bool + debug: bool, ) -> jlong { + println!("[{:?}] decodeShuffleBlock", thread::current().id()); + try_unwrap_or_throw(&e, |mut env| { with_trace("decodeShuffleBlock", tracing_enabled != JNI_FALSE, || { let raw_pointer: *mut u8 = env.get_direct_buffer_address(&byte_buffer)?; @@ -741,7 +749,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock( let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; let batch = read_ipc_compressed(slice)?; if debug { - println!("decode shuffle block from JVM DirectByteBuffer @ {:?}", raw_pointer); + println!( + "decode shuffle block from JVM DirectByteBuffer @ {:?}", + raw_pointer + ); println!("decoded batch: {batch:?}"); } prepare_output(&mut env, array_addrs, schema_addrs, batch, 0, debug) diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index 95630e0057..52f893383c 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -49,6 +49,7 @@ use std::{ pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll}, + thread, }; /// ScanExec reads batches of data from Spark via JNI. 
The source of the scan could be a file @@ -515,7 +516,9 @@ impl Stream for ScanStream<'_> { let maybe_batch = self.build_record_batch(columns, *num_rows); if let Ok(batch) = &maybe_batch { - println!("native got batch from jvm: {:?}", batch); + use libc::pid_t; + let tid: pid_t = unsafe { libc::syscall(libc::SYS_gettid) as pid_t }; + println!("[{:?}] native got batch from jvm: {:?}", tid, batch); } Poll::Ready(Some(maybe_batch)) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 33e6989d4c..9b76fa0154 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -27,6 +27,7 @@ use datafusion::error::Result; use datafusion::physical_plan::metrics::Time; use simd_adler32::Adler32; use std::io::{Cursor, Seek, SeekFrom, Write}; +use std::thread; #[derive(Debug, Clone)] pub enum CompressionCodec { @@ -79,6 +80,16 @@ impl ShuffleBlockWriter { output: &mut W, ipc_time: &Time, ) -> Result { + use libc::pid_t; + let tid: pid_t = unsafe { libc::syscall(libc::SYS_gettid) as pid_t }; + + println!( + "[{:?}] writing shuffle batch: {:?}", + //thread::current().id(), + tid, + batch + ); + if batch.num_rows() == 0 { return Ok(0); } diff --git a/native/core/src/execution/shuffle/row.rs b/native/core/src/execution/shuffle/row.rs index e2f335e1b6..782047946c 100644 --- a/native/core/src/execution/shuffle/row.rs +++ b/native/core/src/execution/shuffle/row.rs @@ -3155,6 +3155,8 @@ pub fn process_sorted_row_partition( initial_checksum: Option, codec: &CompressionCodec, ) -> Result<(i64, Option), CometError> { + println!("process_sorted_row_partition"); + // TODO: We can tune this parameter automatically based on row size and cache size. 
let row_step = 10; diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index 838c8523bb..a0bf19b524 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -91,6 +91,7 @@ impl SparkArrowConvert for ArrayData { // Check if the pointer alignment is correct. if array_ptr.align_offset(array_align) != 0 || schema_ptr.align_offset(schema_align) != 0 { + println!("write_unaligned!"); unsafe { std::ptr::write_unaligned(array_ptr, FFI_ArrowArray::new(self)); std::ptr::write_unaligned(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?); diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java b/spark/src/main/java/org/apache/comet/CometBatchIterator.java index 82a0e21fd3..c3bdd3ca2f 100644 --- a/spark/src/main/java/org/apache/comet/CometBatchIterator.java +++ b/spark/src/main/java/org/apache/comet/CometBatchIterator.java @@ -48,6 +48,8 @@ public class CometBatchIterator { public int hasNext() { if (currentBatch == null) { if (input.hasNext()) { + + System.out.println("CometBatchIterator.hasNext() called"); currentBatch = input.next(); } } @@ -66,12 +68,22 @@ public int hasNext() { * @return the number of rows of the current batch. -1 if there is no more batch. 
*/ public int next(long[] arrayAddrs, long[] schemaAddrs) { + System.out.println("CometBatchIterator.next() called"); if (currentBatch == null) { return -1; } - int numRows = nativeUtil.exportBatch(arrayAddrs, schemaAddrs, currentBatch); + System.out.println("Dropping reference to batch " + currentBatch); currentBatch = null; + + System.gc(); + + try { + Thread.sleep(1000L); + } catch (InterruptedException e) { + + } + return numRows; } } diff --git a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java index 8bc22b3426..de5ce7c17e 100644 --- a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java +++ b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java @@ -478,6 +478,7 @@ public long getMemoryUsage() { @Override protected void spill(int required) throws IOException { + System.out.println("SpillSorter Spilling " + required + " records"); CometShuffleExternalSorter.this.spill(); } diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java index adb228b17d..36806a47b8 100644 --- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java @@ -244,6 +244,9 @@ public long getOutputRecords() { /** Serializes input row and inserts into current allocated page. 
*/ public void insertRow(UnsafeRow row, int partitionId) throws IOException { + + System.out.println("CometDiskBlockWriter.insertRow"); + insertRecords++; if (!initialized) { @@ -430,6 +433,7 @@ long doSpilling(boolean isLast) throws IOException { */ @Override protected void spill(int required) throws IOException { + System.out.println("ArrowIPCWriter.spill(" + required + ")"); // Cannot allocate enough memory, spill and try again synchronized (currentWriters) { // Spill from the largest writer first to maximize the amount of memory we can diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java index 044c7842f0..c4f2c87ec4 100644 --- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java @@ -135,6 +135,7 @@ public boolean acquireNewPageIfNecessary(int required) { // TODO: try to find space in previous pages try { currentPage = allocator.allocate(required); + System.out.println("Spilling to page " + currentPage); } catch (SparkOutOfMemoryError error) { try { // Cannot allocate enough memory, spill diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index c7504f3079..e1e7299fce 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -34,6 +34,8 @@ import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_ import org.apache.comet.Tracing.withTrace import org.apache.comet.vector.NativeUtil +// scalastyle:off + /** * An iterator class used to execute Comet native query. It takes an input iterator which comes * from Comet Scan and is expected to produce batches of Arrow Arrays. 
During consuming this @@ -196,6 +198,7 @@ class CometExecIterator( // This is to guarantee safety at the native side before we overwrite the buffer memory // shared across batches in the native side. if (prevBatch != null) { + println("[" + Thread.currentThread.getId + "] CometExecIterator closing batch") prevBatch.close() prevBatch = null } @@ -213,6 +216,7 @@ class CometExecIterator( override def next(): ColumnarBatch = { if (currentBatch != null) { // Eagerly release Arrow Arrays in the previous batch + println("[" + Thread.currentThread.getId + "] CometExecIterator closing batch") currentBatch.close() currentBatch = null } @@ -230,6 +234,7 @@ class CometExecIterator( def close(): Unit = synchronized { if (!closed) { if (currentBatch != null) { + println("[" + Thread.currentThread.getId + "] CometExecIterator closing batch") currentBatch.close() currentBatch = null } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala index 0f29910b4e..04d3cc8ab3 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala @@ -20,16 +20,19 @@ package org.apache.spark.sql.comet.execution.shuffle import java.io.{EOFException, InputStream} -import java.nio.{ByteBuffer, ByteOrder} +import java.nio.{Buffer, ByteBuffer, ByteOrder} import java.nio.channels.{Channels, ReadableByteChannel} import org.apache.spark.TaskContext +import org.apache.spark.sql.comet.execution.shuffle.NativeBatchDecoderIterator.getByteBufferAddress import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.comet.{CometConf, Native} import org.apache.comet.vector.NativeUtil +// scalastyle:off + /** * This iterator wraps a Spark input 
stream that is reading shuffle blocks generated by the Comet * native ShuffleWriterExec and then calls native code to decompress and decode the shuffle blocks @@ -72,6 +75,7 @@ case class NativeBatchDecoderIterator( // Release the previous batch. if (currentBatch != null) { + println(s"Closing shuffle batch in JVM $currentBatch") currentBatch.close() currentBatch = null } @@ -148,20 +152,26 @@ case class NativeBatchDecoderIterator( // require a 1GB compressed shuffle block but we check anyway val newCapacity = (bytesToRead * 2L).min(Integer.MAX_VALUE).toInt dataBuf = ByteBuffer.allocateDirect(newCapacity) + println(s"Allocated new ByteBuffer ${getByteBufferAddress(dataBuf)}") threadLocalDataBuf.set(dataBuf) } + println(s"BEGIN read into ByteBuffer ${getByteBufferAddress(dataBuf)}") dataBuf.clear() dataBuf.limit(bytesToRead.toInt) while (dataBuf.hasRemaining && channel.read(dataBuf) >= 0) {} if (dataBuf.hasRemaining) { throw new EOFException("Data corrupt: unexpected EOF while reading compressed batch") } + println(s"END read into ByteBuffer ${getByteBufferAddress(dataBuf)}") // make native call to decode batch val startTime = System.nanoTime() val batch = nativeUtil.getNextBatch( fieldCount, (arrayAddrs, schemaAddrs) => { + + println(s"call native decodeShuffleBlock for ByteBuffer ${getByteBufferAddress(dataBuf)}") + native.decodeShuffleBlock( dataBuf, bytesToRead.toInt, @@ -176,13 +186,11 @@ case class NativeBatchDecoderIterator( } def close(): Unit = { - - // scalastyle:off println - println("Closing shuffle batch in JVM") - synchronized { if (!isClosed) { if (currentBatch != null) { + // scalastyle:off println + println(s"Closing shuffle batch in JVM $currentBatch") currentBatch.close() currentBatch = null } @@ -195,6 +203,14 @@ case class NativeBatchDecoderIterator( object NativeBatchDecoderIterator { private val threadLocalDataBuf: ThreadLocal[ByteBuffer] = ThreadLocal.withInitial(() => { - ByteBuffer.allocateDirect(128 * 1024) + val buffer = 
ByteBuffer.allocateDirect(128 * 1024) + println(s"DirectByteBuffer native address: ${getByteBufferAddress(buffer)}") + buffer }) + + def getByteBufferAddress(buffer: ByteBuffer): String = { + val addressField = classOf[Buffer].getDeclaredField("address") + addressField.setAccessible(true) + "0x" + addressField.getLong(buffer).toHexString + } } From 6bbdf1b6ce1410f365548e247858e0bf6c769c5e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 15 Aug 2025 11:35:07 -0600 Subject: [PATCH 13/13] save --- .../src/main/java/org/apache/comet/vector/CometVector.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/src/main/java/org/apache/comet/vector/CometVector.java b/common/src/main/java/org/apache/comet/vector/CometVector.java index e1368e2ae0..e3cc00f602 100644 --- a/common/src/main/java/org/apache/comet/vector/CometVector.java +++ b/common/src/main/java/org/apache/comet/vector/CometVector.java @@ -208,8 +208,9 @@ public ColumnVector getChild(int i) { @Override public void close() { - System.out.println( - "[" + Thread.currentThread().getName() + "] CometVector.close() " + getValueVector()); + String msg = "[" + Thread.currentThread().getName() + "] CometVector.close() " + getValueVector(); + System.out.println(msg); + new RuntimeException(msg).printStackTrace(); getValueVector().close(); }