From 8f22ffe9b380d0745848654de0ae5a4869c1d416 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 23 Mar 2026 07:16:30 -0700 Subject: [PATCH 1/3] refactor: reorganize shuffle crate module structure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split and reorganize the shuffle crate for better cohesion: - Move `CompressionCodec` and `ShuffleBlockWriter` from `codec.rs` into `writers/shuffle_block_writer.rs`; inline the codec enum alongside its primary consumer - Move `Checksum` from `codec.rs` into `writers/checksum.rs` to keep all write-path types together - Rename `codec.rs` → `ipc.rs` (now only contains `read_ipc_compressed`) - Rename `writers/partition_writer.rs` → `writers/spill.rs` to better reflect its spill-management responsibility - Extract `SparkUnsafeObject` trait and `impl_primitive_accessors!` macro from `spark_unsafe/row.rs` into `spark_unsafe/unsafe_object.rs` - Extract `ShufflePartitioner` trait from `partitioners/mod.rs` into `partitioners/traits.rs` - Add concise rustdoc comments to all structs, enums, and traits that were missing them --- .../src/execution/operators/shuffle_scan.rs | 6 +- native/shuffle/src/comet_partitioning.rs | 1 + native/shuffle/src/ipc.rs | 52 ++++ native/shuffle/src/lib.rs | 5 +- native/shuffle/src/metrics.rs | 1 + native/shuffle/src/partitioners/mod.rs | 13 +- .../src/partitioners/multi_partition.rs | 1 + .../partitioned_batch_iterator.rs | 1 + native/shuffle/src/partitioners/traits.rs | 27 +++ native/shuffle/src/spark_unsafe/list.rs | 7 +- native/shuffle/src/spark_unsafe/map.rs | 1 + native/shuffle/src/spark_unsafe/mod.rs | 1 + native/shuffle/src/spark_unsafe/row.rs | 214 +---------------- .../shuffle/src/spark_unsafe/unsafe_object.rs | 224 ++++++++++++++++++ .../shuffle/src/writers/buf_batch_writer.rs | 2 +- native/shuffle/src/writers/checksum.rs | 81 +++++++ native/shuffle/src/writers/mod.rs | 8 +- .../shuffle_block_writer.rs} | 97 +------- .../writers/{partition_writer.rs
=> spill.rs} | 4 +- 19 files changed, 420 insertions(+), 326 deletions(-) create mode 100644 native/shuffle/src/ipc.rs create mode 100644 native/shuffle/src/partitioners/traits.rs create mode 100644 native/shuffle/src/spark_unsafe/unsafe_object.rs create mode 100644 native/shuffle/src/writers/checksum.rs rename native/shuffle/src/{codec.rs => writers/shuffle_block_writer.rs} (60%) rename native/shuffle/src/writers/{partition_writer.rs => spill.rs} (95%) diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index 824965d489..c6f9123211 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -19,7 +19,7 @@ use crate::{ errors::CometError, execution::{ operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, - shuffle::codec::read_ipc_compressed, + shuffle::ipc::read_ipc_compressed, }, jvm_bridge::{jni_call, JVMClasses}, }; @@ -352,7 +352,7 @@ impl RecordBatchStream for ShuffleScanStream { #[cfg(test)] mod tests { - use crate::execution::shuffle::codec::{CompressionCodec, ShuffleBlockWriter}; + use crate::execution::shuffle::{CompressionCodec, ShuffleBlockWriter}; use arrow::array::{Int32Array, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -360,7 +360,7 @@ mod tests { use std::io::Cursor; use std::sync::Arc; - use crate::execution::shuffle::codec::read_ipc_compressed; + use crate::execution::shuffle::ipc::read_ipc_compressed; #[test] #[cfg_attr(miri, ignore)] // Miri cannot call FFI functions (zstd) diff --git a/native/shuffle/src/comet_partitioning.rs b/native/shuffle/src/comet_partitioning.rs index c269539a62..15912e6481 100644 --- a/native/shuffle/src/comet_partitioning.rs +++ b/native/shuffle/src/comet_partitioning.rs @@ -19,6 +19,7 @@ use arrow::row::{OwnedRow, RowConverter}; use datafusion::physical_expr::{LexOrdering, PhysicalExpr}; use std::sync::Arc; +/// Partitioning 
scheme for distributing rows across shuffle output partitions. #[derive(Debug, Clone)] pub enum CometPartitioning { SinglePartition, diff --git a/native/shuffle/src/ipc.rs b/native/shuffle/src/ipc.rs new file mode 100644 index 0000000000..81ee41332a --- /dev/null +++ b/native/shuffle/src/ipc.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::RecordBatch; +use arrow::ipc::reader::StreamReader; +use datafusion::common::DataFusionError; +use datafusion::error::Result; + +pub fn read_ipc_compressed(bytes: &[u8]) -> Result { + match &bytes[0..4] { + b"SNAP" => { + let decoder = snap::read::FrameDecoder::new(&bytes[4..]); + let mut reader = + unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; + reader.next().unwrap().map_err(|e| e.into()) + } + b"LZ4_" => { + let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]); + let mut reader = + unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; + reader.next().unwrap().map_err(|e| e.into()) + } + b"ZSTD" => { + let decoder = zstd::Decoder::new(&bytes[4..])?; + let mut reader = + unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; + reader.next().unwrap().map_err(|e| e.into()) + } + b"NONE" => { + let mut reader = + unsafe { StreamReader::try_new(&bytes[4..], None)?.with_skip_validation(true) }; + reader.next().unwrap().map_err(|e| e.into()) + } + other => Err(DataFusionError::Execution(format!( + "Failed to decode batch: invalid compression codec: {other:?}" + ))), + } +} diff --git a/native/shuffle/src/lib.rs b/native/shuffle/src/lib.rs index 7c2fc8403f..f29588f2e1 100644 --- a/native/shuffle/src/lib.rs +++ b/native/shuffle/src/lib.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. 
-pub mod codec; pub(crate) mod comet_partitioning; +pub mod ipc; pub(crate) mod metrics; pub(crate) mod partitioners; mod shuffle_writer; pub mod spark_unsafe; pub(crate) mod writers; -pub use codec::{read_ipc_compressed, CompressionCodec, ShuffleBlockWriter}; pub use comet_partitioning::CometPartitioning; +pub use ipc::read_ipc_compressed; pub use shuffle_writer::ShuffleWriterExec; +pub use writers::{CompressionCodec, ShuffleBlockWriter}; diff --git a/native/shuffle/src/metrics.rs b/native/shuffle/src/metrics.rs index 1aba4677db..1de751cf41 100644 --- a/native/shuffle/src/metrics.rs +++ b/native/shuffle/src/metrics.rs @@ -19,6 +19,7 @@ use datafusion::physical_plan::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, Time, }; +/// Execution metrics for a shuffle partition operation. pub(crate) struct ShufflePartitionerMetrics { /// metrics pub(crate) baseline: BaselineMetrics, diff --git a/native/shuffle/src/partitioners/mod.rs b/native/shuffle/src/partitioners/mod.rs index a6d589677e..3eedef62c7 100644 --- a/native/shuffle/src/partitioners/mod.rs +++ b/native/shuffle/src/partitioners/mod.rs @@ -18,18 +18,9 @@ mod multi_partition; mod partitioned_batch_iterator; mod single_partition; - -use arrow::record_batch::RecordBatch; -use datafusion::common::Result; +mod traits; pub(crate) use multi_partition::MultiPartitionShuffleRepartitioner; pub(crate) use partitioned_batch_iterator::PartitionedBatchIterator; pub(crate) use single_partition::SinglePartitionShufflePartitioner; - -#[async_trait::async_trait] -pub(crate) trait ShufflePartitioner: Send + Sync { - /// Insert a batch into the partitioner - async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()>; - /// Write shuffle data and shuffle index file to disk - fn shuffle_write(&mut self) -> Result<()>; -} +pub(crate) use traits::ShufflePartitioner; diff --git a/native/shuffle/src/partitioners/multi_partition.rs b/native/shuffle/src/partitioners/multi_partition.rs index 
42290c5510..655bee3511 100644 --- a/native/shuffle/src/partitioners/multi_partition.rs +++ b/native/shuffle/src/partitioners/multi_partition.rs @@ -39,6 +39,7 @@ use std::io::{BufReader, BufWriter, Seek, Write}; use std::sync::Arc; use tokio::time::Instant; +/// Reusable scratch buffers for computing row-to-partition assignments. #[derive(Default)] struct ScratchSpace { /// Hashes for each row in the current batch. diff --git a/native/shuffle/src/partitioners/partitioned_batch_iterator.rs b/native/shuffle/src/partitioners/partitioned_batch_iterator.rs index 77010938cd..8309a8ed4a 100644 --- a/native/shuffle/src/partitioners/partitioned_batch_iterator.rs +++ b/native/shuffle/src/partitioners/partitioned_batch_iterator.rs @@ -50,6 +50,7 @@ impl PartitionedBatchesProducer { } } +/// Iterates over the shuffled record batches belonging to a single output partition. pub(crate) struct PartitionedBatchIterator<'a> { record_batches: Vec<&'a RecordBatch>, batch_size: usize, diff --git a/native/shuffle/src/partitioners/traits.rs b/native/shuffle/src/partitioners/traits.rs new file mode 100644 index 0000000000..9572b70db5 --- /dev/null +++ b/native/shuffle/src/partitioners/traits.rs @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::record_batch::RecordBatch; +use datafusion::common::Result; + +#[async_trait::async_trait] +pub(crate) trait ShufflePartitioner: Send + Sync { + /// Insert a batch into the partitioner + async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()>; + /// Write shuffle data and shuffle index file to disk + fn shuffle_write(&mut self) -> Result<()>; +} diff --git a/native/shuffle/src/spark_unsafe/list.rs b/native/shuffle/src/spark_unsafe/list.rs index 4eb293895c..3fea3fadeb 100644 --- a/native/shuffle/src/spark_unsafe/list.rs +++ b/native/shuffle/src/spark_unsafe/list.rs @@ -17,10 +17,8 @@ use crate::spark_unsafe::{ map::append_map_elements, - row::{ - append_field, downcast_builder_ref, impl_primitive_accessors, SparkUnsafeObject, - SparkUnsafeRow, - }, + row::{append_field, downcast_builder_ref, SparkUnsafeRow}, + unsafe_object::{impl_primitive_accessors, SparkUnsafeObject}, }; use arrow::array::{ builder::{ @@ -86,6 +84,7 @@ macro_rules! impl_append_to_builder { }; } +/// A Spark `UnsafeArray` backed by JVM-allocated memory, providing element access by index. pub struct SparkUnsafeArray { row_addr: i64, num_elements: usize, diff --git a/native/shuffle/src/spark_unsafe/map.rs b/native/shuffle/src/spark_unsafe/map.rs index 57444cee7a..026e6f71d5 100644 --- a/native/shuffle/src/spark_unsafe/map.rs +++ b/native/shuffle/src/spark_unsafe/map.rs @@ -20,6 +20,7 @@ use arrow::array::builder::{ArrayBuilder, MapBuilder, MapFieldNames}; use arrow::datatypes::{DataType, FieldRef}; use datafusion_comet_jni_bridge::errors::CometError; +/// A Spark `UnsafeMap` backed by JVM-allocated memory, containing parallel keys and values arrays. 
pub struct SparkUnsafeMap { pub(crate) keys: SparkUnsafeArray, pub(crate) values: SparkUnsafeArray, diff --git a/native/shuffle/src/spark_unsafe/mod.rs b/native/shuffle/src/spark_unsafe/mod.rs index 6390a0f231..99a24410dd 100644 --- a/native/shuffle/src/spark_unsafe/mod.rs +++ b/native/shuffle/src/spark_unsafe/mod.rs @@ -18,3 +18,4 @@ pub mod list; mod map; pub mod row; +mod unsafe_object; diff --git a/native/shuffle/src/spark_unsafe/row.rs b/native/shuffle/src/spark_unsafe/row.rs index da980af8f9..3c98677199 100644 --- a/native/shuffle/src/spark_unsafe/row.rs +++ b/native/shuffle/src/spark_unsafe/row.rs @@ -17,11 +17,13 @@ //! Utils for supporting native sort-based columnar shuffle. -use crate::codec::{Checksum, ShuffleBlockWriter}; +use crate::spark_unsafe::unsafe_object::{impl_primitive_accessors, SparkUnsafeObject}; use crate::spark_unsafe::{ - list::{append_list_element, SparkUnsafeArray}, - map::{append_map_elements, get_map_key_value_fields, SparkUnsafeMap}, + list::append_list_element, + map::{append_map_elements, get_map_key_value_fields}, }; +use crate::writers::Checksum; +use crate::writers::ShuffleBlockWriter; use arrow::array::{ builder::{ ArrayBuilder, BinaryBuilder, BinaryDictionaryBuilder, BooleanBuilder, Date32Builder, @@ -36,219 +38,17 @@ use arrow::compute::cast; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use arrow::error::ArrowError; use datafusion::physical_plan::metrics::Time; -use datafusion_comet_common::bytes_to_i128; use datafusion_comet_jni_bridge::errors::CometError; use jni::sys::{jint, jlong}; use std::{ fs::OpenOptions, io::{Cursor, Write}, - str::from_utf8, sync::Arc, }; -const MAX_LONG_DIGITS: u8 = 18; const NESTED_TYPE_BUILDER_CAPACITY: usize = 100; -/// A common trait for Spark Unsafe classes that can be used to access the underlying data, -/// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that can be used to -/// access the underlying data with index. 
-/// -/// # Safety -/// -/// Implementations must ensure that: -/// - `get_row_addr()` returns a valid pointer to JVM-allocated memory -/// - `get_element_offset()` returns a valid pointer within the row/array data region -/// - The memory layout follows Spark's UnsafeRow/UnsafeArray format -/// - The memory remains valid for the lifetime of the object (guaranteed by JVM ownership) -/// -/// All accessor methods (get_boolean, get_int, etc.) use unsafe pointer operations but are -/// safe to call as long as: -/// - The index is within bounds (caller's responsibility) -/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data -/// -/// # Alignment -/// -/// Primitive accessor methods are implemented separately for each type because they have -/// different alignment guarantees: -/// - `SparkUnsafeRow`: All field offsets are 8-byte aligned (bitset width is a multiple of 8, -/// and each field slot is 8 bytes), so accessors use aligned `ptr::read()`. -/// - `SparkUnsafeArray`: The array base address may be unaligned when nested within a row's -/// variable-length region, so accessors use `ptr::read_unaligned()`. -pub trait SparkUnsafeObject { - /// Returns the address of the row. - fn get_row_addr(&self) -> i64; - - /// Returns the offset of the element at the given index. - fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8; - - fn get_boolean(&self, index: usize) -> bool; - fn get_byte(&self, index: usize) -> i8; - fn get_short(&self, index: usize) -> i16; - fn get_int(&self, index: usize) -> i32; - fn get_long(&self, index: usize) -> i64; - fn get_float(&self, index: usize) -> f32; - fn get_double(&self, index: usize) -> f64; - fn get_date(&self, index: usize) -> i32; - fn get_timestamp(&self, index: usize) -> i64; - - /// Returns the offset and length of the element at the given index. 
- #[inline] - fn get_offset_and_len(&self, index: usize) -> (i32, i32) { - let offset_and_size = self.get_long(index); - let offset = (offset_and_size >> 32) as i32; - let len = offset_and_size as i32; - (offset, len) - } - - /// Returns string value at the given index of the object. - fn get_string(&self, index: usize) -> &str { - let (offset, len) = self.get_offset_and_len(index); - let addr = self.get_row_addr() + offset as i64; - // SAFETY: addr points to valid UTF-8 string data within the variable-length region. - // Offset and length are read from the fixed-length portion of the row/array. - debug_assert!(addr != 0, "get_string: null address at index {index}"); - debug_assert!( - len >= 0, - "get_string: negative length {len} at index {index}" - ); - let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }; - - from_utf8(slice).unwrap() - } - - /// Returns binary value at the given index of the object. - fn get_binary(&self, index: usize) -> &[u8] { - let (offset, len) = self.get_offset_and_len(index); - let addr = self.get_row_addr() + offset as i64; - // SAFETY: addr points to valid binary data within the variable-length region. - // Offset and length are read from the fixed-length portion of the row/array. - debug_assert!(addr != 0, "get_binary: null address at index {index}"); - debug_assert!( - len >= 0, - "get_binary: negative length {len} at index {index}" - ); - unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) } - } - - /// Returns decimal value at the given index of the object. - fn get_decimal(&self, index: usize, precision: u8) -> i128 { - if precision <= MAX_LONG_DIGITS { - self.get_long(index) as i128 - } else { - let slice = self.get_binary(index); - bytes_to_i128(slice) - } - } - - /// Returns struct value at the given index of the object. 
- fn get_struct(&self, index: usize, num_fields: usize) -> SparkUnsafeRow { - let (offset, len) = self.get_offset_and_len(index); - let mut row = SparkUnsafeRow::new_with_num_fields(num_fields); - row.point_to(self.get_row_addr() + offset as i64, len); - - row - } - - /// Returns array value at the given index of the object. - fn get_array(&self, index: usize) -> SparkUnsafeArray { - let (offset, _) = self.get_offset_and_len(index); - SparkUnsafeArray::new(self.get_row_addr() + offset as i64) - } - - fn get_map(&self, index: usize) -> SparkUnsafeMap { - let (offset, len) = self.get_offset_and_len(index); - SparkUnsafeMap::new(self.get_row_addr() + offset as i64, len) - } -} - -/// Generates primitive accessor implementations for `SparkUnsafeObject`. -/// -/// Uses `$read_method` to read typed values from raw pointers: -/// - `read` for aligned access (SparkUnsafeRow — all offsets are 8-byte aligned) -/// - `read_unaligned` for potentially unaligned access (SparkUnsafeArray) -macro_rules! impl_primitive_accessors { - ($read_method:ident) => { - #[inline] - fn get_boolean(&self, index: usize) -> bool { - let addr = self.get_element_offset(index, 1); - debug_assert!( - !addr.is_null(), - "get_boolean: null pointer at index {index}" - ); - // SAFETY: addr points to valid element data within the row/array region. - unsafe { *addr != 0 } - } - - #[inline] - fn get_byte(&self, index: usize) -> i8 { - let addr = self.get_element_offset(index, 1); - debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}"); - // SAFETY: addr points to valid element data (1 byte) within the row/array region. - unsafe { *(addr as *const i8) } - } - - #[inline] - fn get_short(&self, index: usize) -> i16 { - let addr = self.get_element_offset(index, 2) as *const i16; - debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}"); - // SAFETY: addr points to valid element data (2 bytes) within the row/array region. 
- unsafe { addr.$read_method() } - } - - #[inline] - fn get_int(&self, index: usize) -> i32 { - let addr = self.get_element_offset(index, 4) as *const i32; - debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}"); - // SAFETY: addr points to valid element data (4 bytes) within the row/array region. - unsafe { addr.$read_method() } - } - - #[inline] - fn get_long(&self, index: usize) -> i64 { - let addr = self.get_element_offset(index, 8) as *const i64; - debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}"); - // SAFETY: addr points to valid element data (8 bytes) within the row/array region. - unsafe { addr.$read_method() } - } - - #[inline] - fn get_float(&self, index: usize) -> f32 { - let addr = self.get_element_offset(index, 4) as *const f32; - debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}"); - // SAFETY: addr points to valid element data (4 bytes) within the row/array region. - unsafe { addr.$read_method() } - } - - #[inline] - fn get_double(&self, index: usize) -> f64 { - let addr = self.get_element_offset(index, 8) as *const f64; - debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}"); - // SAFETY: addr points to valid element data (8 bytes) within the row/array region. - unsafe { addr.$read_method() } - } - - #[inline] - fn get_date(&self, index: usize) -> i32 { - let addr = self.get_element_offset(index, 4) as *const i32; - debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}"); - // SAFETY: addr points to valid element data (4 bytes) within the row/array region. - unsafe { addr.$read_method() } - } - - #[inline] - fn get_timestamp(&self, index: usize) -> i64 { - let addr = self.get_element_offset(index, 8) as *const i64; - debug_assert!( - !addr.is_null(), - "get_timestamp: null pointer at index {index}" - ); - // SAFETY: addr points to valid element data (8 bytes) within the row/array region. 
- unsafe { addr.$read_method() } - } - }; -} -pub(crate) use impl_primitive_accessors; - +/// A Spark `UnsafeRow` backed by JVM-allocated memory, providing field access by index. pub struct SparkUnsafeRow { row_addr: i64, row_size: i32, @@ -323,7 +123,7 @@ impl SparkUnsafeRow { } /// Points the row to the given address with specified row size. - fn point_to(&mut self, row_addr: i64, row_size: i32) { + pub(crate) fn point_to(&mut self, row_addr: i64, row_size: i32) { self.row_addr = row_addr; self.row_size = row_size; } diff --git a/native/shuffle/src/spark_unsafe/unsafe_object.rs b/native/shuffle/src/spark_unsafe/unsafe_object.rs new file mode 100644 index 0000000000..f32ea8c23b --- /dev/null +++ b/native/shuffle/src/spark_unsafe/unsafe_object.rs @@ -0,0 +1,224 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::list::SparkUnsafeArray; +use super::map::SparkUnsafeMap; +use super::row::SparkUnsafeRow; +use datafusion_comet_common::bytes_to_i128; +use std::str::from_utf8; + +const MAX_LONG_DIGITS: u8 = 18; + +/// A common trait for Spark Unsafe classes that can be used to access the underlying data, +/// e.g., `UnsafeRow` and `UnsafeArray`. 
This defines a set of methods that can be used to +/// access the underlying data with index. +/// +/// # Safety +/// +/// Implementations must ensure that: +/// - `get_row_addr()` returns a valid pointer to JVM-allocated memory +/// - `get_element_offset()` returns a valid pointer within the row/array data region +/// - The memory layout follows Spark's UnsafeRow/UnsafeArray format +/// - The memory remains valid for the lifetime of the object (guaranteed by JVM ownership) +/// +/// All accessor methods (get_boolean, get_int, etc.) use unsafe pointer operations but are +/// safe to call as long as: +/// - The index is within bounds (caller's responsibility) +/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data +/// +/// # Alignment +/// +/// Primitive accessor methods are implemented separately for each type because they have +/// different alignment guarantees: +/// - `SparkUnsafeRow`: All field offsets are 8-byte aligned (bitset width is a multiple of 8, +/// and each field slot is 8 bytes), so accessors use aligned `ptr::read()`. +/// - `SparkUnsafeArray`: The array base address may be unaligned when nested within a row's +/// variable-length region, so accessors use `ptr::read_unaligned()`. +pub trait SparkUnsafeObject { + /// Returns the address of the row. + fn get_row_addr(&self) -> i64; + + /// Returns the offset of the element at the given index. + fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8; + + fn get_boolean(&self, index: usize) -> bool; + fn get_byte(&self, index: usize) -> i8; + fn get_short(&self, index: usize) -> i16; + fn get_int(&self, index: usize) -> i32; + fn get_long(&self, index: usize) -> i64; + fn get_float(&self, index: usize) -> f32; + fn get_double(&self, index: usize) -> f64; + fn get_date(&self, index: usize) -> i32; + fn get_timestamp(&self, index: usize) -> i64; + + /// Returns the offset and length of the element at the given index. 
+ #[inline] + fn get_offset_and_len(&self, index: usize) -> (i32, i32) { + let offset_and_size = self.get_long(index); + let offset = (offset_and_size >> 32) as i32; + let len = offset_and_size as i32; + (offset, len) + } + + /// Returns string value at the given index of the object. + fn get_string(&self, index: usize) -> &str { + let (offset, len) = self.get_offset_and_len(index); + let addr = self.get_row_addr() + offset as i64; + // SAFETY: addr points to valid UTF-8 string data within the variable-length region. + // Offset and length are read from the fixed-length portion of the row/array. + debug_assert!(addr != 0, "get_string: null address at index {index}"); + debug_assert!( + len >= 0, + "get_string: negative length {len} at index {index}" + ); + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }; + + from_utf8(slice).unwrap() + } + + /// Returns binary value at the given index of the object. + fn get_binary(&self, index: usize) -> &[u8] { + let (offset, len) = self.get_offset_and_len(index); + let addr = self.get_row_addr() + offset as i64; + // SAFETY: addr points to valid binary data within the variable-length region. + // Offset and length are read from the fixed-length portion of the row/array. + debug_assert!(addr != 0, "get_binary: null address at index {index}"); + debug_assert!( + len >= 0, + "get_binary: negative length {len} at index {index}" + ); + unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) } + } + + /// Returns decimal value at the given index of the object. + fn get_decimal(&self, index: usize, precision: u8) -> i128 { + if precision <= MAX_LONG_DIGITS { + self.get_long(index) as i128 + } else { + let slice = self.get_binary(index); + bytes_to_i128(slice) + } + } + + /// Returns struct value at the given index of the object. 
+ fn get_struct(&self, index: usize, num_fields: usize) -> SparkUnsafeRow { + let (offset, len) = self.get_offset_and_len(index); + let mut row = SparkUnsafeRow::new_with_num_fields(num_fields); + row.point_to(self.get_row_addr() + offset as i64, len); + + row + } + + /// Returns array value at the given index of the object. + fn get_array(&self, index: usize) -> SparkUnsafeArray { + let (offset, _) = self.get_offset_and_len(index); + SparkUnsafeArray::new(self.get_row_addr() + offset as i64) + } + + fn get_map(&self, index: usize) -> SparkUnsafeMap { + let (offset, len) = self.get_offset_and_len(index); + SparkUnsafeMap::new(self.get_row_addr() + offset as i64, len) + } +} + +/// Generates primitive accessor implementations for `SparkUnsafeObject`. +/// +/// Uses `$read_method` to read typed values from raw pointers: +/// - `read` for aligned access (SparkUnsafeRow — all offsets are 8-byte aligned) +/// - `read_unaligned` for potentially unaligned access (SparkUnsafeArray) +macro_rules! impl_primitive_accessors { + ($read_method:ident) => { + #[inline] + fn get_boolean(&self, index: usize) -> bool { + let addr = self.get_element_offset(index, 1); + debug_assert!( + !addr.is_null(), + "get_boolean: null pointer at index {index}" + ); + // SAFETY: addr points to valid element data within the row/array region. + unsafe { *addr != 0 } + } + + #[inline] + fn get_byte(&self, index: usize) -> i8 { + let addr = self.get_element_offset(index, 1); + debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}"); + // SAFETY: addr points to valid element data (1 byte) within the row/array region. + unsafe { *(addr as *const i8) } + } + + #[inline] + fn get_short(&self, index: usize) -> i16 { + let addr = self.get_element_offset(index, 2) as *const i16; + debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}"); + // SAFETY: addr points to valid element data (2 bytes) within the row/array region. 
+ unsafe { addr.$read_method() } + } + + #[inline] + fn get_int(&self, index: usize) -> i32 { + let addr = self.get_element_offset(index, 4) as *const i32; + debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}"); + // SAFETY: addr points to valid element data (4 bytes) within the row/array region. + unsafe { addr.$read_method() } + } + + #[inline] + fn get_long(&self, index: usize) -> i64 { + let addr = self.get_element_offset(index, 8) as *const i64; + debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}"); + // SAFETY: addr points to valid element data (8 bytes) within the row/array region. + unsafe { addr.$read_method() } + } + + #[inline] + fn get_float(&self, index: usize) -> f32 { + let addr = self.get_element_offset(index, 4) as *const f32; + debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}"); + // SAFETY: addr points to valid element data (4 bytes) within the row/array region. + unsafe { addr.$read_method() } + } + + #[inline] + fn get_double(&self, index: usize) -> f64 { + let addr = self.get_element_offset(index, 8) as *const f64; + debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}"); + // SAFETY: addr points to valid element data (8 bytes) within the row/array region. + unsafe { addr.$read_method() } + } + + #[inline] + fn get_date(&self, index: usize) -> i32 { + let addr = self.get_element_offset(index, 4) as *const i32; + debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}"); + // SAFETY: addr points to valid element data (4 bytes) within the row/array region. + unsafe { addr.$read_method() } + } + + #[inline] + fn get_timestamp(&self, index: usize) -> i64 { + let addr = self.get_element_offset(index, 8) as *const i64; + debug_assert!( + !addr.is_null(), + "get_timestamp: null pointer at index {index}" + ); + // SAFETY: addr points to valid element data (8 bytes) within the row/array region. 
+ unsafe { addr.$read_method() } + } + }; +} +pub(crate) use impl_primitive_accessors; diff --git a/native/shuffle/src/writers/buf_batch_writer.rs b/native/shuffle/src/writers/buf_batch_writer.rs index 6344a8e5f2..cfddb46539 100644 --- a/native/shuffle/src/writers/buf_batch_writer.rs +++ b/native/shuffle/src/writers/buf_batch_writer.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::ShuffleBlockWriter; +use super::ShuffleBlockWriter; use arrow::array::RecordBatch; use arrow::compute::kernels::coalesce::BatchCoalescer; use datafusion::physical_plan::metrics::Time; diff --git a/native/shuffle/src/writers/checksum.rs b/native/shuffle/src/writers/checksum.rs new file mode 100644 index 0000000000..b240302e66 --- /dev/null +++ b/native/shuffle/src/writers/checksum.rs @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use bytes::Buf; +use crc32fast::Hasher; +use datafusion_comet_jni_bridge::errors::{CometError, CometResult}; +use simd_adler32::Adler32; +use std::io::{Cursor, SeekFrom}; + +/// Checksum algorithms for writing IPC bytes. +#[derive(Clone)] +pub(crate) enum Checksum { + /// CRC32 checksum algorithm. 
+    CRC32(Hasher),
+    /// Adler32 checksum algorithm.
+    Adler32(Adler32),
+}
+
+impl Checksum {
+    pub(crate) fn try_new(algo: i32, initial_opt: Option<u32>) -> CometResult<Self> {
+        match algo {
+            0 => {
+                let hasher = if let Some(initial) = initial_opt {
+                    Hasher::new_with_initial(initial)
+                } else {
+                    Hasher::new()
+                };
+                Ok(Checksum::CRC32(hasher))
+            }
+            1 => {
+                let hasher = if let Some(initial) = initial_opt {
+                    // Note that Adler32 initial state is not zero.
+                    // i.e., `Adler32::from_checksum(0)` is not the same as `Adler32::new()`.
+                    Adler32::from_checksum(initial)
+                } else {
+                    Adler32::new()
+                };
+                Ok(Checksum::Adler32(hasher))
+            }
+            _ => Err(CometError::Internal(
+                "Unsupported checksum algorithm".to_string(),
+            )),
+        }
+    }
+
+    pub(crate) fn update(&mut self, cursor: &mut Cursor<&mut Vec<u8>>) -> CometResult<()> {
+        match self {
+            Checksum::CRC32(hasher) => {
+                std::io::Seek::seek(cursor, SeekFrom::Start(0))?;
+                hasher.update(cursor.chunk());
+                Ok(())
+            }
+            Checksum::Adler32(hasher) => {
+                std::io::Seek::seek(cursor, SeekFrom::Start(0))?;
+                hasher.write(cursor.chunk());
+                Ok(())
+            }
+        }
+    }
+
+    pub(crate) fn finalize(self) -> u32 {
+        match self {
+            Checksum::CRC32(hasher) => hasher.finalize(),
+            Checksum::Adler32(hasher) => hasher.finish(),
+        }
+    }
+}
diff --git a/native/shuffle/src/writers/mod.rs b/native/shuffle/src/writers/mod.rs
index b58989e46c..75caf9f3a3 100644
--- a/native/shuffle/src/writers/mod.rs
+++ b/native/shuffle/src/writers/mod.rs
@@ -16,7 +16,11 @@
 // under the License.
mod buf_batch_writer; -mod partition_writer; +mod checksum; +mod shuffle_block_writer; +mod spill; pub(crate) use buf_batch_writer::BufBatchWriter; -pub(crate) use partition_writer::PartitionWriter; +pub(crate) use checksum::Checksum; +pub use shuffle_block_writer::{CompressionCodec, ShuffleBlockWriter}; +pub(crate) use spill::PartitionWriter; diff --git a/native/shuffle/src/codec.rs b/native/shuffle/src/writers/shuffle_block_writer.rs similarity index 60% rename from native/shuffle/src/codec.rs rename to native/shuffle/src/writers/shuffle_block_writer.rs index c8edc2468c..5ed5330e3a 100644 --- a/native/shuffle/src/codec.rs +++ b/native/shuffle/src/writers/shuffle_block_writer.rs @@ -17,17 +17,13 @@ use arrow::array::RecordBatch; use arrow::datatypes::Schema; -use arrow::ipc::reader::StreamReader; use arrow::ipc::writer::StreamWriter; -use bytes::Buf; -use crc32fast::Hasher; use datafusion::common::DataFusionError; use datafusion::error::Result; use datafusion::physical_plan::metrics::Time; -use datafusion_comet_jni_bridge::errors::{CometError, CometResult}; -use simd_adler32::Adler32; use std::io::{Cursor, Seek, SeekFrom, Write}; +/// Compression algorithm applied to shuffle IPC blocks. #[derive(Debug, Clone)] pub enum CompressionCodec { None, @@ -36,6 +32,7 @@ pub enum CompressionCodec { Snappy, } +/// Writes a record batch as a length-prefixed, compressed Arrow IPC block. 
 #[derive(Clone)]
 pub struct ShuffleBlockWriter {
     codec: CompressionCodec,
@@ -147,93 +144,3 @@ impl ShuffleBlockWriter {
         Ok((end_pos - start_pos) as usize)
     }
 }
-
-pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> {
-    match &bytes[0..4] {
-        b"SNAP" => {
-            let decoder = snap::read::FrameDecoder::new(&bytes[4..]);
-            let mut reader =
-                unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) };
-            reader.next().unwrap().map_err(|e| e.into())
-        }
-        b"LZ4_" => {
-            let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]);
-            let mut reader =
-                unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) };
-            reader.next().unwrap().map_err(|e| e.into())
-        }
-        b"ZSTD" => {
-            let decoder = zstd::Decoder::new(&bytes[4..])?;
-            let mut reader =
-                unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) };
-            reader.next().unwrap().map_err(|e| e.into())
-        }
-        b"NONE" => {
-            let mut reader =
-                unsafe { StreamReader::try_new(&bytes[4..], None)?.with_skip_validation(true) };
-            reader.next().unwrap().map_err(|e| e.into())
-        }
-        other => Err(DataFusionError::Execution(format!(
-            "Failed to decode batch: invalid compression codec: {other:?}"
-        ))),
-    }
-}
-
-/// Checksum algorithms for writing IPC bytes.
-#[derive(Clone)]
-pub(crate) enum Checksum {
-    /// CRC32 checksum algorithm.
-    CRC32(Hasher),
-    /// Adler32 checksum algorithm.
-    Adler32(Adler32),
-}
-
-impl Checksum {
-    pub(crate) fn try_new(algo: i32, initial_opt: Option<u32>) -> CometResult<Self> {
-        match algo {
-            0 => {
-                let hasher = if let Some(initial) = initial_opt {
-                    Hasher::new_with_initial(initial)
-                } else {
-                    Hasher::new()
-                };
-                Ok(Checksum::CRC32(hasher))
-            }
-            1 => {
-                let hasher = if let Some(initial) = initial_opt {
-                    // Note that Adler32 initial state is not zero.
-                    // i.e., `Adler32::from_checksum(0)` is not the same as `Adler32::new()`.
-                    Adler32::from_checksum(initial)
-                } else {
-                    Adler32::new()
-                };
-                Ok(Checksum::Adler32(hasher))
-            }
-            _ => Err(CometError::Internal(
-                "Unsupported checksum algorithm".to_string(),
-            )),
-        }
-    }
-
-    pub(crate) fn update(&mut self, cursor: &mut Cursor<&mut Vec<u8>>) -> CometResult<()> {
-        match self {
-            Checksum::CRC32(hasher) => {
-                std::io::Seek::seek(cursor, SeekFrom::Start(0))?;
-                hasher.update(cursor.chunk());
-                Ok(())
-            }
-            Checksum::Adler32(hasher) => {
-                std::io::Seek::seek(cursor, SeekFrom::Start(0))?;
-                hasher.write(cursor.chunk());
-                Ok(())
-            }
-        }
-    }
-
-    pub(crate) fn finalize(self) -> u32 {
-        match self {
-            Checksum::CRC32(hasher) => hasher.finalize(),
-            Checksum::Adler32(hasher) => hasher.finish(),
-        }
-    }
-}
diff --git a/native/shuffle/src/writers/partition_writer.rs b/native/shuffle/src/writers/spill.rs
similarity index 95%
rename from native/shuffle/src/writers/partition_writer.rs
rename to native/shuffle/src/writers/spill.rs
index 48017871db..c16caddbf9 100644
--- a/native/shuffle/src/writers/partition_writer.rs
+++ b/native/shuffle/src/writers/spill.rs
@@ -15,20 +15,22 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use super::ShuffleBlockWriter;
 use crate::metrics::ShufflePartitionerMetrics;
 use crate::partitioners::PartitionedBatchIterator;
 use crate::writers::buf_batch_writer::BufBatchWriter;
-use crate::ShuffleBlockWriter;
 use datafusion::common::DataFusionError;
 use datafusion::execution::disk_manager::RefCountedTempFile;
 use datafusion::execution::runtime_env::RuntimeEnv;
 use std::fs::{File, OpenOptions};
 
+/// A temporary disk file for spilling a partition's intermediate shuffle data.
 struct SpillFile {
     temp_file: RefCountedTempFile,
     file: File,
 }
 
+/// Manages encoding and optional disk spilling for a single shuffle partition.
 pub(crate) struct PartitionWriter {
     /// Spill file for intermediate shuffle output for this partition.
Each spill event /// will append to this file and the contents will be copied to the shuffle file at From b9c0db179b4db07bcde11e9eafcaf042f1f1dec1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 23 Mar 2026 07:37:27 -0700 Subject: [PATCH 2/3] cargo fmt --- native/core/src/execution/operators/shuffle_scan.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index c6f9123211..a1ad52310c 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -18,8 +18,7 @@ use crate::{ errors::CometError, execution::{ - operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, - shuffle::ipc::read_ipc_compressed, + operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, shuffle::ipc::read_ipc_compressed, }, jvm_bridge::{jni_call, JVMClasses}, }; From 5cf70fcf242fccaf658de406a218efa9da32b8ce Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 23 Mar 2026 07:53:29 -0700 Subject: [PATCH 3/3] fix: make SparkUnsafeObject accessible from shuffle bench Make `unsafe_object` module public and update the bench import to use the correct path for `SparkUnsafeObject`. 
--- native/shuffle/benches/row_columnar.rs | 5 ++--- native/shuffle/src/spark_unsafe/mod.rs | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/native/shuffle/benches/row_columnar.rs b/native/shuffle/benches/row_columnar.rs index 7d3951b4d5..cc98f3faca 100644 --- a/native/shuffle/benches/row_columnar.rs +++ b/native/shuffle/benches/row_columnar.rs @@ -23,9 +23,8 @@ use arrow::datatypes::{DataType as ArrowDataType, Field, Fields}; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use datafusion_comet_shuffle::spark_unsafe::row::{ - process_sorted_row_partition, SparkUnsafeObject, SparkUnsafeRow, -}; +use datafusion_comet_shuffle::spark_unsafe::row::{process_sorted_row_partition, SparkUnsafeRow}; +use datafusion_comet_shuffle::spark_unsafe::unsafe_object::SparkUnsafeObject; use datafusion_comet_shuffle::CompressionCodec; use std::sync::Arc; use tempfile::Builder; diff --git a/native/shuffle/src/spark_unsafe/mod.rs b/native/shuffle/src/spark_unsafe/mod.rs index 99a24410dd..abda69a087 100644 --- a/native/shuffle/src/spark_unsafe/mod.rs +++ b/native/shuffle/src/spark_unsafe/mod.rs @@ -18,4 +18,4 @@ pub mod list; mod map; pub mod row; -mod unsafe_object; +pub mod unsafe_object;