diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs index 1ca41d4fe21ca..4ed7ba59b1238 100644 --- a/datafusion/physical-plan/src/sorts/cursor.rs +++ b/datafusion/physical-plan/src/sorts/cursor.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::cmp::Ordering; +use std::{cmp::Ordering, sync::Arc}; use arrow::buffer::ScalarBuffer; use arrow::compute::SortOptions; @@ -28,9 +28,10 @@ use datafusion_execution::memory_pool::MemoryReservation; /// A [`Cursor`] for [`Rows`] pub struct RowCursor { cur_row: usize, - num_rows: usize, + row_offset: usize, + row_limit: usize, // exclusive [offset..limit] - rows: Rows, + rows: Arc, /// Tracks for the memory used by in the `Rows` of this /// cursor. Freed on drop @@ -42,7 +43,7 @@ impl std::fmt::Debug for RowCursor { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.debug_struct("SortKeyCursor") .field("cur_row", &self.cur_row) - .field("num_rows", &self.num_rows) + .field("num_rows", &self.num_rows()) .finish() } } @@ -62,8 +63,9 @@ impl RowCursor { assert!(rows.num_rows() > 0); Self { cur_row: 0, - num_rows: rows.num_rows(), - rows, + row_offset: 0, + row_limit: rows.num_rows(), + rows: Arc::new(rows), reservation, } } @@ -104,19 +106,47 @@ pub trait Cursor: Ord { /// Advance the cursor, returning the previous row index fn advance(&mut self) -> usize; + + /// Slice the cursor at a given row index, returning a new cursor + /// + /// # Panics + /// + /// Panics if the slice is out of bounds, or memory is insufficient + fn slice(&self, offset: usize, length: usize) -> Self + where + Self: Sized; + + /// Returns the number of rows in this cursor + fn num_rows(&self) -> usize; } impl Cursor for RowCursor { #[inline] fn is_finished(&self) -> bool { - self.num_rows == self.cur_row + self.cur_row >= self.row_limit } #[inline] fn advance(&mut self) -> usize { let t = self.cur_row; self.cur_row += 1; - t + t - self.row_offset + } + + #[inline] + fn slice(&self, offset: usize, length: usize) -> Self { + Self { + cur_row: self.row_offset + offset, + row_offset: self.row_offset + offset, + row_limit: self.row_offset + offset + length, + rows: self.rows.clone(), + reservation: self.reservation.new_empty(), // Arc cloning of Rows is cheap + } + } + + #[inline] + fn num_rows(&self) -> usize { + self.row_limit - self.row_offset } } @@ -136,6 +166,10 @@ pub trait FieldValues { fn compare(a: &Self::Value, b: &Self::Value) -> Ordering; fn value(&self, idx: usize) -> &Self::Value; + + fn slice(&self, offset: usize, length: usize) -> Self + where + Self: Sized; } impl FieldArray for PrimitiveArray { @@ -165,6 +199,12 @@ impl FieldValues for PrimitiveValues { fn value(&self, idx: usize) -> &Self::Value { &self.0[idx] } + + #[inline] + fn slice(&self, offset: usize, length: usize) -> Self { + assert!(offset + length <= self.len(), "cursor slice out of bounds"); + Self(self.0.slice(offset, length)) + } } impl FieldArray for GenericByteArray { @@ -196,6 +236,15 @@ impl FieldValues for GenericByteArray { fn value(&self, idx: usize) -> &Self::Value { self.value(idx) } + + #[inline] + fn slice(&self, offset: usize, length: usize) -> Self { + assert!( + offset + length <= Array::len(self), + "cursor slice out of bounds" + ); + self.slice(offset, length) + } } /// A cursor over sorted, nullable [`FieldValues`] @@ -284,6 +333,34 @@ impl Cursor for FieldCursor { self.offset += 1; t } + + fn slice(&self, offset: usize, length: usize) -> Self { + let FieldCursor { + values, + offset: _, + null_threshold, + options, + } = self; + + let null_threshold = match self.options.nulls_first { + true => null_threshold.saturating_sub(offset), + false => { + let shorter_len = self.values.len().saturating_sub(offset + length + 1); + null_threshold.saturating_sub(offset.saturating_sub(shorter_len)) + } + }; + + Self { + values: values.slice(offset, length), + offset: 0, + null_threshold, + options: *options, + } + } + + fn num_rows(&self) -> usize { + self.values.len() + } } #[cfg(test)] @@ -308,6 +385,25 @@ mod tests { } } + #[test] + fn test_primitive_null_mask() { + let options = SortOptions { + descending: false, + nulls_first: true, + }; + + let is_min = new_primitive(options, ScalarBuffer::from(vec![i32::MIN]), 0); + assert_eq!(is_min.num_rows(), 1); + let is_null = new_primitive(options, ScalarBuffer::from(vec![i32::MIN]), 1); + assert_eq!(is_null.num_rows(), 1); + + // i32::MIN != NULL + assert_ne!(is_min.cmp(&is_null), Ordering::Equal); // is null mask + + assert!(is_null.is_null()); + assert!(!is_min.is_null()); + } + #[test] fn test_primitive_nulls_first() { let options = SortOptions { @@ -354,6 +450,11 @@ mod tests { a.advance(); assert_eq!(a.cmp(&b), Ordering::Less); + // finished + assert!(!b.is_finished()); + b.advance(); + assert!(b.is_finished()); + let options = SortOptions { descending: false, nulls_first: false, @@ -380,6 +481,12 @@ mod tests { assert_eq!(a.cmp(&b), Ordering::Equal); assert_eq!(a, b); + // finished + assert!(!a.is_finished()); + a.advance(); + a.advance(); + assert!(a.is_finished()); + let options = SortOptions { descending: true, nulls_first: false, @@ -441,4 +548,132 @@ mod tests { b.advance(); assert_eq!(a.cmp(&b), Ordering::Less); } + + #[test] + fn test_slice_primitive() { + let options = SortOptions { + descending: false, + nulls_first: true, + }; + + let buffer = ScalarBuffer::from(vec![0, 1, 2]); + let mut cursor = new_primitive(options, buffer, 0); + + // from start + let sliced = cursor.slice(0, 1); + assert_eq!(sliced.num_rows(), 1); + let expected = new_primitive(options, ScalarBuffer::from(vec![0]), 0); + assert_eq!( + sliced.cmp(&expected), + Ordering::Equal, + "should slice from start" + ); + + // with offset + let sliced = cursor.slice(1, 2); + assert_eq!(sliced.num_rows(), 2); + let expected = new_primitive(options, ScalarBuffer::from(vec![1]), 0); + assert_eq!( + sliced.cmp(&expected), + Ordering::Equal, + "should slice with offset" + ); + + // cursor current position != start + cursor.advance(); + let sliced = cursor.slice(0, 1); + assert_eq!(sliced.num_rows(), 1); + let expected = new_primitive(options, ScalarBuffer::from(vec![0]), 0); + assert_eq!( + sliced.cmp(&expected), + Ordering::Equal, + "should ignore current cursor position when sliced" + ); + } + + #[test] + #[should_panic(expected = "cursor slice out of bounds")] + fn test_slice_panic_can_panic() { + let options = SortOptions { + descending: false, + nulls_first: true, + }; + + let buffer = ScalarBuffer::from(vec![0, 1, 2]); + let cursor = new_primitive(options, buffer, 0); + + cursor.slice(42, 1); + } + + #[test] + fn test_slice_nulls_first() { + let options = SortOptions { + descending: false, + nulls_first: true, + }; + + let is_min = new_primitive(options, ScalarBuffer::from(vec![i32::MIN]), 0); + + let buffer = ScalarBuffer::from(vec![i32::MIN, 79, 2, i32::MIN]); + let mut a = new_primitive(options, buffer, 2); + assert_eq!(a.num_rows(), 4); + let buffer = ScalarBuffer::from(vec![i32::MIN, -284, 3, i32::MIN, 2]); + let mut b = new_primitive(options, buffer, 2); + assert_eq!(b.num_rows(), 5); + + // NULL == NULL + assert!(a.is_null()); + assert_eq!(a.cmp(&b), Ordering::Equal); + + // i32::MIN > NULL + a = a.slice(3, 1); + assert_eq!(a, is_min); + assert_eq!(a.cmp(&b), Ordering::Greater); + + // i32::MIN == i32::MIN + b = b.slice(3, 2); + assert_eq!(b, is_min); + assert_eq!(a.cmp(&b), Ordering::Equal); + + // i32::MIN < 2 + b = b.slice(1, 1); + assert_eq!(a.cmp(&b), Ordering::Less); + } + + #[test] + fn test_slice_nulls_last() { + let options = SortOptions { + descending: false, + nulls_first: false, + }; + + let is_min = new_primitive(options, ScalarBuffer::from(vec![i32::MIN]), 0); + + let buffer = ScalarBuffer::from(vec![i32::MIN, 79, 2, i32::MIN]); + let mut a = new_primitive(options, buffer, 2); + assert_eq!(a.num_rows(), 4); + let buffer = ScalarBuffer::from(vec![i32::MIN, -284, 3, i32::MIN, 2]); + let mut b = new_primitive(options, buffer, 2); + assert_eq!(b.num_rows(), 5); + + // i32::MIN == i32::MIN + assert_eq!(a, is_min); + assert_eq!(a.cmp(&b), Ordering::Equal); + + // i32::MIN < -284 + b = b.slice(1, 3); // slice to full length + assert_eq!(a.cmp(&b), Ordering::Less); + + // 79 > -284 + a = a.slice(1, 2); // slice to shorter than full length + assert!(!a.is_null()); + assert_eq!(a.cmp(&b), Ordering::Greater); + + // NULL == NULL + a = a.slice(1, 1); + b = b.slice(2, 1); + assert!(a.is_null()); + assert!(b.is_null()); + assert_eq!(a.cmp(&b), Ordering::Equal); + } }