From 6eda8be02570691e0311801f4862f8046ebd9f29 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 9 Apr 2026 14:07:54 -0400 Subject: [PATCH 01/11] Implement radix sort. --- arrow-row/src/lib.rs | 1 + arrow-row/src/radix.rs | 638 +++++++++++++++++++++++++++++++++++++++ arrow/benches/lexsort.rs | 16 + 3 files changed, 655 insertions(+) create mode 100644 arrow-row/src/radix.rs diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 078c4574775d..db35632d3e86 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -180,6 +180,7 @@ use arrow_array::types::{Int16Type, Int32Type, Int64Type}; mod fixed; mod list; +pub mod radix; mod run; mod variable; diff --git a/arrow-row/src/radix.rs b/arrow-row/src/radix.rs new file mode 100644 index 000000000000..4e099de9618c --- /dev/null +++ b/arrow-row/src/radix.rs @@ -0,0 +1,638 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! MSD radix sort on row-encoded keys. +//! +//! The Arrow row format produces big-endian, memcmp-comparable byte sequences, +//! making it ideal for MSD (Most Significant Digit) radix sort without any +//! additional encoding. This gives O(n × key_width) performance instead of +//! O(n log n × comparison_cost). +//! +//! # When to use this +//! +//! Radix sort on row-encoded keys is the fastest sort strategy when: +//! - **Primitive columns** (integers, floats): ~2.4x faster than [`lexsort_to_indices`] +//! at N=32768 despite the encoding overhead. +//! - **String columns**: 1.3–1.9x faster than the best alternative at all sizes. +//! The advantage grows with more string columns and larger N. +//! - **Mixed dict + string columns**: ~1.3x faster than row-format comparison sort. +//! - **List columns with other columns**: Competitive or faster when lists aren't +//! the primary sort key. +//! +//! # When NOT to use this +//! +//! - **Low-cardinality dictionary-only columns**: MSD radix sort performs poorly +//! when many rows share long identical key prefixes, because it must recurse +//! byte-by-byte through the shared prefix before reaching a distinguishing byte. +//! For example, 2 low-cardinality dict columns is ~7x *slower* than +//! [`lexsort_to_indices`]. Mixing in a string or primitive column eliminates +//! this problem. +//! - **List-heavy sorts where a leading primitive column discriminates most rows**: +//! [`lexsort_to_indices`] avoids encoding the expensive list column entirely for +//! most rows, so it wins when the leading column is highly selective. +//! +//! As a rule of thumb: if your sort key is dominated by low-cardinality columns +//! with no high-cardinality column to break ties early, prefer [`lexsort_to_indices`]. +//! Otherwise, radix sort is likely faster. +//! +//! [`lexsort_to_indices`]: arrow_ord::sort::lexsort_to_indices + +use crate::Rows; + +/// Bucket size at which we fall back to comparison sort. +const FALLBACK_THRESHOLD: usize = 64; + +/// Maximum byte depth before falling back to comparison sort. +const MAX_DEPTH: usize = 128; + +/// Sort row indices using MSD radix sort on row-encoded keys. +/// +/// Takes [`Rows`] produced by [`RowConverter::convert_columns`] and returns +/// a `Vec` of row indices in sorted order. The caller is responsible for +/// encoding columns into row format and for using the returned indices to +/// reorder the original arrays (e.g., via [`take`]). +/// +/// See the [module-level documentation](self) for guidance on when radix sort +/// is faster than [`lexsort_to_indices`]. +/// +/// # Example +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_row::{RowConverter, SortField}; +/// # use arrow_row::radix::radix_sort_to_indices; +/// # use arrow_array::{Int32Array, ArrayRef}; +/// # use arrow_schema::DataType; +/// let array: ArrayRef = Arc::new(Int32Array::from(vec![5, 3, 1, 4, 2])); +/// let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); +/// let rows = converter.convert_columns(&[array]).unwrap(); +/// let indices = radix_sort_to_indices(&rows); +/// assert_eq!(indices, vec![2, 4, 1, 3, 0]); // points to [1, 2, 3, 4, 5] +/// ``` +/// +/// [`RowConverter::convert_columns`]: crate::RowConverter::convert_columns +/// [`take`]: https://docs.rs/arrow/latest/arrow/compute/fn.take.html +/// [`lexsort_to_indices`]: arrow_ord::sort::lexsort_to_indices +pub fn radix_sort_to_indices(rows: &Rows) -> Vec { + let n = rows.num_rows(); + let mut indices: Vec = (0..n as u32).collect(); + msd_radix_sort(&mut indices, rows, 0); + indices +} + +fn msd_radix_sort(indices: &mut [u32], rows: &Rows, byte_pos: usize) { + if indices.len() <= FALLBACK_THRESHOLD || byte_pos >= MAX_DEPTH { + indices.sort_unstable_by(|&a, &b| { + let ra = unsafe { rows.row_unchecked(a as usize) }; + let rb = unsafe { rows.row_unchecked(b as usize) }; + ra.cmp(&rb) + }); + return; + } + + // Histogram of byte values at byte_pos + let mut counts = [0u32; 256]; + for &idx in indices.iter() { + let row = unsafe { rows.row_unchecked(idx as usize) }; + let byte = row.data().get(byte_pos).copied().unwrap_or(0); + counts[byte as usize] += 1; + } + + // All same byte — skip to next position + if counts.iter().filter(|&&c| c > 0).count() == 1 { + msd_radix_sort(indices, rows, byte_pos + 1); + return; + } + + // Prefix sum for bucket offsets + let mut offsets = [0u32; 257]; + for i in 0..256 { + offsets[i + 1] = offsets[i] + counts[i]; + } + + // Out-of-place scatter into buckets + let mut temp = vec![0u32; indices.len()]; + let mut write_pos = offsets; + for &idx in indices.iter() { + let row = unsafe { rows.row_unchecked(idx as usize) }; + let byte = row.data().get(byte_pos).copied().unwrap_or(0) as usize; + temp[write_pos[byte] as usize] = idx; + write_pos[byte] += 1; + } + indices.copy_from_slice(&temp); + + // Recurse into non-trivial buckets + for bucket in 0..256 { + let start = offsets[bucket] as usize; + let end = offsets[bucket + 1] as usize; + if end - start > 1 { + msd_radix_sort(&mut indices[start..end], rows, byte_pos + 1); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{RowConverter, SortField}; + use arrow_array::{ + ArrayRef, BooleanArray, Float64Array, Int32Array, StringArray, + }; + use arrow_schema::{DataType, SortOptions}; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::sync::Arc; + + fn assert_sorted(rows: &Rows, indices: &[u32]) { + for i in 1..indices.len() { + let a = unsafe { rows.row_unchecked(indices[i - 1] as usize) }; + let b = unsafe { rows.row_unchecked(indices[i] as usize) }; + assert!(a <= b, "row {} should be <= row {}", indices[i - 1], indices[i]); + } + } + + #[test] + fn test_radix_sort_integers() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![5, 3, 1, 4, 2])); + let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + assert_eq!(indices, vec![2, 4, 1, 3, 0]); + } + + #[test] + fn test_radix_sort_strings() { + let array: ArrayRef = Arc::new(StringArray::from(vec!["banana", "apple", "cherry", "date"])); + let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + assert_eq!(indices, vec![1, 0, 2, 3]); + } + + #[test] + fn test_radix_sort_with_nulls() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(3), + None, + Some(1), + None, + Some(2), + ])); + let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + } + + #[test] + fn test_radix_sort_multi_column() { + let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 1, 2, 2])); + let a2: ArrayRef = Arc::new(StringArray::from(vec!["b", "a", "d", "c"])); + let converter = RowConverter::new(vec![ + SortField::new(DataType::Int32), + SortField::new(DataType::Utf8), + ]) + .unwrap(); + let rows = converter.convert_columns(&[a1, a2]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + assert_eq!(indices, vec![1, 0, 3, 2]); + } + + #[test] + fn test_radix_sort_empty() { + let array: ArrayRef = Arc::new(Int32Array::from(Vec::::new())); + let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert!(indices.is_empty()); + } + + #[test] + fn test_radix_sort_single_element() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![42])); + let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_eq!(indices, vec![0]); + } + + #[test] + fn test_radix_sort_all_equal() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![7, 7, 7, 7, 7])); + let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + } + + #[test] + fn test_radix_sort_descending() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 2, 5, 4])); + let options = SortOptions::default().desc(); + let converter = + RowConverter::new(vec![SortField::new_with_options(DataType::Int32, options)]) + .unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + } + + #[test] + fn test_radix_sort_nulls_first() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(3), + None, + Some(1), + None, + Some(2), + ])); + let options = SortOptions::default().with_nulls_first(true); + let converter = + RowConverter::new(vec![SortField::new_with_options(DataType::Int32, options)]) + .unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + } + + #[test] + fn test_radix_sort_descending_nulls_first() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(3), + None, + Some(1), + None, + Some(2), + ])); + let options = SortOptions::default().desc().with_nulls_first(true); + let converter = + RowConverter::new(vec![SortField::new_with_options(DataType::Int32, options)]) + .unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + } + + #[test] + fn test_radix_sort_all_sort_option_combos() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(5), + None, + Some(3), + None, + Some(1), + Some(4), + Some(2), + ])); + + for descending in [false, true] { + for nulls_first in [false, true] { + let options = SortOptions { + descending, + nulls_first, + }; + let converter = RowConverter::new(vec![SortField::new_with_options( + DataType::Int32, + options, + )]) + .unwrap(); + let rows = converter.convert_columns(&[array.clone()]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + } + } + } + + #[test] + fn test_radix_sort_floats_with_nan() { + let array: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(1.0), + Some(f64::NAN), + None, + Some(-1.0), + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + Some(0.0), + ])); + let converter = RowConverter::new(vec![SortField::new(DataType::Float64)]).unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + } + + #[test] + fn test_radix_sort_booleans() { + let array: ArrayRef = Arc::new(BooleanArray::from(vec![ + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + ])); + let converter = RowConverter::new(vec![SortField::new(DataType::Boolean)]).unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + } + + // Tests sizes around the FALLBACK_THRESHOLD (64) to exercise both paths + #[test] + fn test_radix_sort_threshold_boundary() { + let mut rng = StdRng::seed_from_u64(0xCAFE); + for n in [1, 2, 32, 63, 64, 65, 100, 128, 256, 500, 1000] { + let values: Vec> = (0..n) + .map(|_| { + if rng.random_bool(0.1) { + None + } else { + Some(rng.random_range(-1000..1000)) + } + }) + .collect(); + let array: ArrayRef = Arc::new(Int32Array::from(values)); + let converter = + RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + assert_eq!(indices.len(), n, "wrong number of indices for n={n}"); + } + } + + /// Generate random arrays and verify radix sort produces correctly sorted output + /// across random column types, sort options, and sizes. + #[test] + fn test_radix_sort_fuzz() { + let mut rng = StdRng::seed_from_u64(0xF00D); + + for iteration in 0..100 { + let num_columns = rng.random_range(1..=4); + let len = rng.random_range(5..500); + + let mut arrays: Vec = Vec::new(); + let mut fields: Vec = Vec::new(); + + for _ in 0..num_columns { + let options = SortOptions { + descending: rng.random_bool(0.5), + nulls_first: rng.random_bool(0.5), + }; + let null_rate = if rng.random_bool(0.3) { 0.0 } else { 0.2 }; + + // Pick a random column type + let (array, dt) = match rng.random_range(0..7) { + 0 => { + let vals: Vec> = (0..len) + .map(|_| { + if rng.random_bool(null_rate) { + None + } else { + Some(rng.random_range(-10000..10000)) + } + }) + .collect(); + ( + Arc::new(Int32Array::from(vals)) as ArrayRef, + DataType::Int32, + ) + } + 1 => { + let vals: Vec> = (0..len) + .map(|_| { + if rng.random_bool(null_rate) { + None + } else { + Some(rng.random()) + } + }) + .collect(); + ( + Arc::new(arrow_array::Int64Array::from(vals)) as ArrayRef, + DataType::Int64, + ) + } + 2 => { + let vals: Vec> = (0..len) + .map(|_| { + if rng.random_bool(null_rate) { + None + } else { + Some(rng.random::() * 1000.0 - 500.0) + } + }) + .collect(); + ( + Arc::new(Float64Array::from(vals)) as ArrayRef, + DataType::Float64, + ) + } + 3 => { + let vals: Vec> = (0..len) + .map(|_| { + if rng.random_bool(null_rate) { + None + } else { + // Fixed set of strings to get some collisions + Some( + ["alpha", "beta", "gamma", "delta", "epsilon", + "zeta", "eta", "theta", "iota", "kappa", + "a longer string for testing", ""] + [rng.random_range(0..12)], + ) + } + }) + .collect(); + ( + Arc::new(StringArray::from(vals)) as ArrayRef, + DataType::Utf8, + ) + } + 4 => { + let vals: Vec> = (0..len) + .map(|_| { + if rng.random_bool(null_rate) { + None + } else { + Some(rng.random_bool(0.5)) + } + }) + .collect(); + ( + Arc::new(BooleanArray::from(vals)) as ArrayRef, + DataType::Boolean, + ) + } + 5 => { + // Low-cardinality i32 to create many ties + let vals: Vec> = (0..len) + .map(|_| { + if rng.random_bool(null_rate) { + None + } else { + Some(rng.random_range(0..5)) + } + }) + .collect(); + ( + Arc::new(Int32Array::from(vals)) as ArrayRef, + DataType::Int32, + ) + } + _ => { + // All-null column + let vals: Vec> = vec![None; len]; + ( + Arc::new(Int32Array::from(vals)) as ArrayRef, + DataType::Int32, + ) + } + }; + + arrays.push(array); + fields.push(SortField::new_with_options(dt, options)); + } + + let converter = RowConverter::new(fields).unwrap(); + let rows = converter.convert_columns(&arrays).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + assert_eq!( + indices.len(), + len, + "iteration {iteration}: wrong index count" + ); + + // Verify every original index appears exactly once + let mut seen = vec![false; len]; + for &idx in &indices { + assert!( + !seen[idx as usize], + "iteration {iteration}: duplicate index {idx}" + ); + seen[idx as usize] = true; + } + } + } + + /// Verify radix sort matches comparison sort on row-encoded keys. + /// Uses the same Rows, so any difference is a bug in the radix sort itself. + #[test] + fn test_radix_matches_comparison_sort() { + let mut rng = StdRng::seed_from_u64(0xBEEF); + + for _ in 0..50 { + let len = rng.random_range(100..1000); + let vals: Vec> = (0..len) + .map(|_| { + if rng.random_bool(0.15) { + None + } else { + Some(rng.random_range(-500..500)) + } + }) + .collect(); + + let options = SortOptions { + descending: rng.random_bool(0.5), + nulls_first: rng.random_bool(0.5), + }; + + let array: ArrayRef = Arc::new(Int32Array::from(vals)); + let converter = RowConverter::new(vec![SortField::new_with_options( + DataType::Int32, + options, + )]) + .unwrap(); + let rows = converter.convert_columns(&[array]).unwrap(); + + let radix = radix_sort_to_indices(&rows); + + let mut comparison: Vec = (0..len as u32).collect(); + comparison.sort_unstable_by(|&a, &b| { + rows.row(a as usize).cmp(&rows.row(b as usize)) + }); + + // Both sorts operate on the same rows, so equal-keyed elements + // should appear in the same relative order only if both are stable + // (they aren't), but the *row values* at each position must match. + for i in 0..len { + let radix_row = rows.row(radix[i] as usize); + let cmp_row = rows.row(comparison[i] as usize); + assert_eq!( + radix_row, cmp_row, + "mismatch at position {i}: radix idx={} vs comparison idx={}", + radix[i], comparison[i] + ); + } + } + } + + /// Test with a multi-column schema that has mixed sort options per column. + #[test] + fn test_radix_sort_multi_column_mixed_options() { + let a1: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(1), + Some(2), + None, + ])); + let a2: ArrayRef = Arc::new(StringArray::from(vec![ + Some("z"), + Some("a"), + Some("a"), + None, + Some("m"), + ])); + + // Col 1: ascending nulls_last, Col 2: descending nulls_first + let converter = RowConverter::new(vec![ + SortField::new_with_options( + DataType::Int32, + SortOptions::default().asc().with_nulls_first(false), + ), + SortField::new_with_options( + DataType::Utf8, + SortOptions::default().desc().with_nulls_first(true), + ), + ]) + .unwrap(); + let rows = converter.convert_columns(&[a1, a2]).unwrap(); + + let indices = radix_sort_to_indices(&rows); + assert_sorted(&rows, &indices); + } +} diff --git a/arrow/benches/lexsort.rs b/arrow/benches/lexsort.rs index 16a2606b919a..f9028ef1ac99 100644 --- a/arrow/benches/lexsort.rs +++ b/arrow/benches/lexsort.rs @@ -16,6 +16,7 @@ // under the License. use arrow::compute::{SortColumn, lexsort_to_indices}; +use arrow::row::radix::radix_sort_to_indices; use arrow::row::{RowConverter, SortField}; use arrow::util::bench_util::{ create_dict_from_values, create_primitive_array, create_string_array_with_len, @@ -146,6 +147,21 @@ fn do_bench(c: &mut Criterion, columns: &[Column], len: usize) { }) }) }); + + c.bench_function(&format!("lexsort_radix({columns:?}): {len}"), |b| { + b.iter(|| { + hint::black_box({ + let fields = arrays + .iter() + .map(|a| SortField::new(a.data_type().clone())) + .collect(); + let converter = RowConverter::new(fields).unwrap(); + let rows = converter.convert_columns(&arrays).unwrap(); + let indices = radix_sort_to_indices(&rows); + UInt32Array::from_iter_values(indices) + }) + }) + }); } fn add_benchmark(c: &mut Criterion) { From 5fb42319334d72f1f428fb38beb140b6dbd9963b Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 9 Apr 2026 15:12:43 -0400 Subject: [PATCH 02/11] Update comments. --- arrow-row/src/radix.rs | 135 ++++++++++++++++++++------------------- arrow/benches/lexsort.rs | 6 -- 2 files changed, 70 insertions(+), 71 deletions(-) diff --git a/arrow-row/src/radix.rs b/arrow-row/src/radix.rs index 4e099de9618c..f08f6cfe9987 100644 --- a/arrow-row/src/radix.rs +++ b/arrow-row/src/radix.rs @@ -24,40 +24,42 @@ //! //! # When to use this //! -//! Radix sort on row-encoded keys is the fastest sort strategy when: -//! - **Primitive columns** (integers, floats): ~2.4x faster than [`lexsort_to_indices`] -//! at N=32768 despite the encoding overhead. -//! - **String columns**: 1.3–1.9x faster than the best alternative at all sizes. -//! The advantage grows with more string columns and larger N. -//! - **Mixed dict + string columns**: ~1.3x faster than row-format comparison sort. -//! - **List columns with other columns**: Competitive or faster when lists aren't -//! the primary sort key. +//! Radix sort on row-encoded keys is the fastest sort strategy for most +//! multi-column sorts, including: +//! - **Primitive columns** (integers, floats) +//! - **String columns**, especially multiple string columns +//! - **Mixed column types** (primitives, strings, dicts, lists) //! -//! # When NOT to use this +//! The advantage over [`lexsort_to_indices`] grows with N and with the +//! number of columns. //! -//! - **Low-cardinality dictionary-only columns**: MSD radix sort performs poorly -//! when many rows share long identical key prefixes, because it must recurse -//! byte-by-byte through the shared prefix before reaching a distinguishing byte. -//! For example, 2 low-cardinality dict columns is ~7x *slower* than -//! [`lexsort_to_indices`]. Mixing in a string or primitive column eliminates -//! this problem. -//! - **List-heavy sorts where a leading primitive column discriminates most rows**: -//! [`lexsort_to_indices`] avoids encoding the expensive list column entirely for -//! most rows, so it wins when the leading column is highly selective. +//! # When NOT to use this //! -//! As a rule of thumb: if your sort key is dominated by low-cardinality columns -//! with no high-cardinality column to break ties early, prefer [`lexsort_to_indices`]. -//! Otherwise, radix sort is likely faster. +//! Prefer [`lexsort_to_indices`] when: +//! - **All sort columns are low-cardinality dictionaries** with no +//! high-cardinality column to break ties. The row encoding for +//! dictionary values produces long shared prefixes, and radix sort +//! gains little from its first few byte passes before falling back +//! to comparison sort. +//! - **A leading primitive column discriminates most rows and a trailing +//! column is expensive to encode** (e.g., lists). [`lexsort_to_indices`] +//! avoids encoding the trailing column for rows already resolved by +//! the leading column. //! -//! [`lexsort_to_indices`]: arrow_ord::sort::lexsort_to_indices +//! [`lexsort_to_indices`]: https://docs.rs/arrow-ord/latest/arrow_ord/sort/fn.lexsort_to_indices.html use crate::Rows; -/// Bucket size at which we fall back to comparison sort. +/// When a bucket has this few elements, the fixed per-level cost of radix +/// sort (256-bucket histogram + scatter) exceeds the O(n log n) cost of +/// comparison sort with small n and warm cache lines. const FALLBACK_THRESHOLD: usize = 64; -/// Maximum byte depth before falling back to comparison sort. -const MAX_DEPTH: usize = 128; +/// Beyond this depth, comparison sort on the full row handles the +/// remaining discrimination. 8 bytes covers the discriminating prefix +/// of most key layouts; deeper recursion hits diminishing returns as +/// buckets become sparse and the per-level overhead dominates. +const MAX_DEPTH: usize = 8; /// Sort row indices using MSD radix sort on row-encoded keys. /// @@ -86,7 +88,7 @@ const MAX_DEPTH: usize = 128; /// /// [`RowConverter::convert_columns`]: crate::RowConverter::convert_columns /// [`take`]: https://docs.rs/arrow/latest/arrow/compute/fn.take.html -/// [`lexsort_to_indices`]: arrow_ord::sort::lexsort_to_indices +/// [`lexsort_to_indices`]: https://docs.rs/arrow-ord/latest/arrow_ord/sort/fn.lexsort_to_indices.html pub fn radix_sort_to_indices(rows: &Rows) -> Vec { let n = rows.num_rows(); let mut indices: Vec = (0..n as u32).collect(); @@ -104,30 +106,30 @@ fn msd_radix_sort(indices: &mut [u32], rows: &Rows, byte_pos: usize) { return; } - // Histogram of byte values at byte_pos let mut counts = [0u32; 256]; - for &idx in indices.iter() { + // Reborrow as &[u32] so pattern gives u32 not &mut u32 + for &idx in &*indices { let row = unsafe { rows.row_unchecked(idx as usize) }; let byte = row.data().get(byte_pos).copied().unwrap_or(0); counts[byte as usize] += 1; } - // All same byte — skip to next position + // No discrimination at this byte position — all rows have the same value if counts.iter().filter(|&&c| c > 0).count() == 1 { msd_radix_sort(indices, rows, byte_pos + 1); return; } - // Prefix sum for bucket offsets let mut offsets = [0u32; 257]; for i in 0..256 { offsets[i + 1] = offsets[i] + counts[i]; } - // Out-of-place scatter into buckets + // Out-of-place scatter avoids the complexity of in-place partitioning + // across 256 buckets let mut temp = vec![0u32; indices.len()]; let mut write_pos = offsets; - for &idx in indices.iter() { + for &idx in &*indices { let row = unsafe { rows.row_unchecked(idx as usize) }; let byte = row.data().get(byte_pos).copied().unwrap_or(0) as usize; temp[write_pos[byte] as usize] = idx; @@ -135,7 +137,6 @@ fn msd_radix_sort(indices: &mut [u32], rows: &Rows, byte_pos: usize) { } indices.copy_from_slice(&temp); - // Recurse into non-trivial buckets for bucket in 0..256 { let start = offsets[bucket] as usize; let end = offsets[bucket + 1] as usize; @@ -149,9 +150,7 @@ fn msd_radix_sort(indices: &mut [u32], rows: &Rows, byte_pos: usize) { mod tests { use super::*; use crate::{RowConverter, SortField}; - use arrow_array::{ - ArrayRef, BooleanArray, Float64Array, Int32Array, StringArray, - }; + use arrow_array::{ArrayRef, BooleanArray, Float64Array, Int32Array, StringArray}; use arrow_schema::{DataType, SortOptions}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -159,9 +158,14 @@ mod tests { fn assert_sorted(rows: &Rows, indices: &[u32]) { for i in 1..indices.len() { - let a = unsafe { rows.row_unchecked(indices[i - 1] as usize) }; - let b = unsafe { rows.row_unchecked(indices[i] as usize) }; - assert!(a <= b, "row {} should be <= row {}", indices[i - 1], indices[i]); + let a = rows.row(indices[i - 1] as usize); + let b = rows.row(indices[i] as usize); + assert!( + a <= b, + "row {} should be <= row {}", + indices[i - 1], + indices[i] + ); } } @@ -178,7 +182,8 @@ mod tests { #[test] fn test_radix_sort_strings() { - let array: ArrayRef = Arc::new(StringArray::from(vec!["banana", "apple", "cherry", "date"])); + let array: ArrayRef = + Arc::new(StringArray::from(vec!["banana", "apple", "cherry", "date"])); let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); let rows = converter.convert_columns(&[array]).unwrap(); @@ -254,8 +259,7 @@ mod tests { let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 2, 5, 4])); let options = SortOptions::default().desc(); let converter = - RowConverter::new(vec![SortField::new_with_options(DataType::Int32, options)]) - .unwrap(); + RowConverter::new(vec![SortField::new_with_options(DataType::Int32, options)]).unwrap(); let rows = converter.convert_columns(&[array]).unwrap(); let indices = radix_sort_to_indices(&rows); @@ -273,8 +277,7 @@ mod tests { ])); let options = SortOptions::default().with_nulls_first(true); let converter = - RowConverter::new(vec![SortField::new_with_options(DataType::Int32, options)]) - .unwrap(); + RowConverter::new(vec![SortField::new_with_options(DataType::Int32, options)]).unwrap(); let rows = converter.convert_columns(&[array]).unwrap(); let indices = radix_sort_to_indices(&rows); @@ -292,8 +295,7 @@ mod tests { ])); let options = SortOptions::default().desc().with_nulls_first(true); let converter = - RowConverter::new(vec![SortField::new_with_options(DataType::Int32, options)]) - .unwrap(); + RowConverter::new(vec![SortField::new_with_options(DataType::Int32, options)]).unwrap(); let rows = converter.convert_columns(&[array]).unwrap(); let indices = radix_sort_to_indices(&rows); @@ -318,11 +320,9 @@ mod tests { descending, nulls_first, }; - let converter = RowConverter::new(vec![SortField::new_with_options( - DataType::Int32, - options, - )]) - .unwrap(); + let converter = + RowConverter::new(vec![SortField::new_with_options(DataType::Int32, options)]) + .unwrap(); let rows = converter.convert_columns(&[array.clone()]).unwrap(); let indices = radix_sort_to_indices(&rows); @@ -381,8 +381,7 @@ mod tests { }) .collect(); let array: ArrayRef = Arc::new(Int32Array::from(values)); - let converter = - RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); + let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); let rows = converter.convert_columns(&[array]).unwrap(); let indices = radix_sort_to_indices(&rows); @@ -466,10 +465,20 @@ mod tests { } else { // Fixed set of strings to get some collisions Some( - ["alpha", "beta", "gamma", "delta", "epsilon", - "zeta", "eta", "theta", "iota", "kappa", - "a longer string for testing", ""] - [rng.random_range(0..12)], + [ + "alpha", + "beta", + "gamma", + "delta", + "epsilon", + "zeta", + "eta", + "theta", + "iota", + "kappa", + "a longer string for testing", + "", + ][rng.random_range(0..12)], ) } }) @@ -571,19 +580,15 @@ mod tests { }; let array: ArrayRef = Arc::new(Int32Array::from(vals)); - let converter = RowConverter::new(vec![SortField::new_with_options( - DataType::Int32, - options, - )]) - .unwrap(); + let converter = + RowConverter::new(vec![SortField::new_with_options(DataType::Int32, options)]) + .unwrap(); let rows = converter.convert_columns(&[array]).unwrap(); let radix = radix_sort_to_indices(&rows); let mut comparison: Vec = (0..len as u32).collect(); - comparison.sort_unstable_by(|&a, &b| { - rows.row(a as usize).cmp(&rows.row(b as usize)) - }); + comparison.sort_unstable_by(|&a, &b| rows.row(a as usize).cmp(&rows.row(b as usize))); // Both sorts operate on the same rows, so equal-keyed elements // should appear in the same relative order only if both are stable diff --git a/arrow/benches/lexsort.rs b/arrow/benches/lexsort.rs index f9028ef1ac99..b7adc59fa6b1 100644 --- a/arrow/benches/lexsort.rs +++ b/arrow/benches/lexsort.rs @@ -202,12 +202,6 @@ fn add_benchmark(c: &mut Criterion) { Column::Optional100Value50CharStringDict, Column::Optional50CharString, ], - &[ - Column::Optional100Value50CharStringDict, - Column::Optional100Value50CharStringDict, - Column::Optional100Value50CharStringDict, - Column::Optional50CharString, - ], &[Column::OptionalI32, Column::RequiredI32List], &[Column::OptionalI32, Column::OptionalI32List], &[Column::OptionalI32List, Column::OptionalI32], From 026a2540cbf71c250ab66839ffd774cb42cefd97 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 9 Apr 2026 15:46:46 -0400 Subject: [PATCH 03/11] Hoist temp buffer out. --- arrow-row/src/radix.rs | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/arrow-row/src/radix.rs b/arrow-row/src/radix.rs index f08f6cfe9987..b3044ab59f92 100644 --- a/arrow-row/src/radix.rs +++ b/arrow-row/src/radix.rs @@ -92,13 +92,17 @@ const MAX_DEPTH: usize = 8; pub fn radix_sort_to_indices(rows: &Rows) -> Vec { let n = rows.num_rows(); let mut indices: Vec = (0..n as u32).collect(); - msd_radix_sort(&mut indices, rows, 0); + let mut temp = vec![0u32; n]; + msd_radix_sort(&mut indices, &mut temp, rows, 0); indices } -fn msd_radix_sort(indices: &mut [u32], rows: &Rows, byte_pos: usize) { - if indices.len() <= FALLBACK_THRESHOLD || byte_pos >= MAX_DEPTH { +fn msd_radix_sort(indices: &mut [u32], temp: &mut [u32], rows: &Rows, byte_pos: usize) { + let n = indices.len(); + + if n <= FALLBACK_THRESHOLD || byte_pos >= MAX_DEPTH { indices.sort_unstable_by(|&a, &b| { + // SAFETY: indices contains a permutation of 0..rows.num_rows() let ra = unsafe { rows.row_unchecked(a as usize) }; let rb = unsafe { rows.row_unchecked(b as usize) }; ra.cmp(&rb) @@ -106,9 +110,13 @@ fn msd_radix_sort(indices: &mut [u32], rows: &Rows, byte_pos: usize) { return; } + // Both the histogram and scatter loops read each row's byte via + // row_unchecked. Pre-extracting bytes into a contiguous buffer was + // tried but benchmarked slower — the extra write pass costs more + // than the second read through row offsets already hot in cache. let mut counts = [0u32; 256]; - // Reborrow as &[u32] so pattern gives u32 not &mut u32 for &idx in &*indices { + // SAFETY: indices contains a permutation of 0..rows.num_rows() let row = unsafe { rows.row_unchecked(idx as usize) }; let byte = row.data().get(byte_pos).copied().unwrap_or(0); counts[byte as usize] += 1; @@ -116,7 +124,7 @@ fn msd_radix_sort(indices: &mut [u32], rows: &Rows, byte_pos: usize) { // No discrimination at this byte position — all rows have the same value if counts.iter().filter(|&&c| c > 0).count() == 1 { - msd_radix_sort(indices, rows, byte_pos + 1); + msd_radix_sort(indices, temp, rows, byte_pos + 1); return; } @@ -127,21 +135,27 @@ fn msd_radix_sort(indices: &mut [u32], rows: &Rows, byte_pos: usize) { // Out-of-place scatter avoids the complexity of in-place partitioning // across 256 buckets - let mut temp = vec![0u32; indices.len()]; + let temp = &mut temp[..n]; let mut write_pos = offsets; for &idx in &*indices { + // SAFETY: indices contains a permutation of 0..rows.num_rows() let row = unsafe { rows.row_unchecked(idx as usize) }; let byte = row.data().get(byte_pos).copied().unwrap_or(0) as usize; temp[write_pos[byte] as usize] = idx; write_pos[byte] += 1; } - indices.copy_from_slice(&temp); + indices.copy_from_slice(temp); for bucket in 0..256 { let start = offsets[bucket] as usize; let end = offsets[bucket + 1] as usize; if end - start > 1 { - msd_radix_sort(&mut indices[start..end], rows, byte_pos + 1); + msd_radix_sort( + &mut indices[start..end], + &mut temp[start..end], + rows, + byte_pos + 1, + ); } } } From c8b7a96f07cb14dd1e9624e8c80a756abeac409a Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 9 Apr 2026 16:00:14 -0400 Subject: [PATCH 04/11] Clarify confusing comment. --- arrow-row/src/radix.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arrow-row/src/radix.rs b/arrow-row/src/radix.rs index b3044ab59f92..dc4b4497e111 100644 --- a/arrow-row/src/radix.rs +++ b/arrow-row/src/radix.rs @@ -122,7 +122,8 @@ fn msd_radix_sort(indices: &mut [u32], temp: &mut [u32], rows: &Rows, byte_pos: counts[byte as usize] += 1; } - // No discrimination at this byte position — all rows have the same value + // Skip scatter when all rows share the same byte — one bucket + // with everything in it is just wasted work if counts.iter().filter(|&&c| c > 0).count() == 1 { msd_radix_sort(indices, temp, rows, byte_pos + 1); return; From 8401ebeec666fcacbc3b0d041b08015432b9845b Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 9 Apr 2026 16:10:29 -0400 Subject: [PATCH 05/11] Fix clippy. --- arrow-row/src/radix.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arrow-row/src/radix.rs b/arrow-row/src/radix.rs index dc4b4497e111..9d073897c322 100644 --- a/arrow-row/src/radix.rs +++ b/arrow-row/src/radix.rs @@ -338,7 +338,9 @@ mod tests { let converter = RowConverter::new(vec![SortField::new_with_options(DataType::Int32, options)]) .unwrap(); - let rows = converter.convert_columns(&[array.clone()]).unwrap(); + let rows = converter + .convert_columns(std::slice::from_ref(&array)) + .unwrap(); let indices = radix_sort_to_indices(&rows); assert_sorted(&rows, &indices); From f9a4c41ab095f4d7e037e9f4c89e32d2d1d94b19 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 9 Apr 2026 20:33:16 -0400 Subject: [PATCH 06/11] Address PR feedback. --- arrow-row/src/lib.rs | 25 +++++++++ arrow-row/src/radix.rs | 112 ++++++++++++++++++++++++++++++----------- 2 files changed, 107 insertions(+), 30 deletions(-) diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index db35632d3e86..e6c2293f7ce4 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -1452,6 +1452,31 @@ impl<'a> Row<'a> { pub fn data(&self) -> &'a [u8] { self.data } + + /// The byte at `offset`, or 0 if `offset` is past the end of the row. + #[inline] + pub fn byte_from(&self, offset: usize) -> u8 { + if offset < self.data.len() { + // SAFETY: bounds checked above + unsafe { *self.data.get_unchecked(offset) } + } else { + 0 + } + } + + /// The row's bytes starting at `offset`, or an empty slice if + /// `offset` exceeds the row length. + /// + /// Useful for comparing rows that share a known prefix (e.g., + /// after radix sort has already discriminated on earlier bytes). + pub fn data_from(&self, offset: usize) -> &'a [u8] { + if offset <= self.data.len() { + // SAFETY: bounds checked above + unsafe { self.data.get_unchecked(offset..) } + } else { + &[] + } + } } // Manually derive these as don't wish to include `fields` diff --git a/arrow-row/src/radix.rs b/arrow-row/src/radix.rs index 9d073897c322..78401b45bc6d 100644 --- a/arrow-row/src/radix.rs +++ b/arrow-row/src/radix.rs @@ -93,20 +93,66 @@ pub fn radix_sort_to_indices(rows: &Rows) -> Vec { let n = rows.num_rows(); let mut indices: Vec = (0..n as u32).collect(); let mut temp = vec![0u32; n]; - msd_radix_sort(&mut indices, &mut temp, rows, 0); + let mut bytes = vec![0u8; n]; + msd_radix_sort(&mut indices, &mut temp, &mut bytes, rows, 0, true); indices } -fn msd_radix_sort(indices: &mut [u32], temp: &mut [u32], rows: &Rows, byte_pos: usize) { - let n = indices.len(); +/// Returns the byte at `byte_pos` for the row at `idx`, or 0 if the row +/// is shorter. +/// +/// # Safety +/// `idx` must be a valid row index in `rows`. +#[inline(always)] +unsafe fn row_byte(rows: &Rows, idx: u32, byte_pos: usize) -> u8 { + // SAFETY: caller guarantees idx is a valid row index + unsafe { rows.row_unchecked(idx as usize) }.byte_from(byte_pos) +} + +/// MSD radix sort using ping-pong buffers. +/// +/// Each level scatters from `src` into `dst`, then recurses with the +/// roles swapped (dst becomes the next level's src). This avoids an +/// O(n) `copy_from_slice` at every recursion level. +/// +/// `result_in_src` tracks where the caller expects the sorted output: +/// true means `src`, false means `dst`. It flips at each scatter so +/// the final result lands in the right buffer. The top-level call +/// passes `true` so the answer ends up in `indices`. +fn msd_radix_sort( + src: &mut [u32], + dst: &mut [u32], + rows: &Rows, + byte_pos: usize, + result_in_src: bool, +) { + let n = src.len(); if n <= FALLBACK_THRESHOLD || byte_pos >= MAX_DEPTH { - indices.sort_unstable_by(|&a, &b| { - // SAFETY: indices contains a permutation of 0..rows.num_rows() - let ra = unsafe { rows.row_unchecked(a as usize) }; - let rb = unsafe { rows.row_unchecked(b as usize) }; - ra.cmp(&rb) - }); + // Compare only from byte_pos onward — earlier bytes are identical + // within this bucket, having already been discriminated by radix + // passes above us. Safe slice via get() is needed because rows of + // different lengths can share a bucket when a shorter row's + // past-end default (0) matches a longer row's real byte value. + // + // When !result_in_src the caller expects the output in dst, so + // we copy first and sort in place there. + if result_in_src { + src.sort_unstable_by(|&a, &b| { + // SAFETY: indices contains a permutation of 0..rows.num_rows() + let ra = unsafe { rows.row_unchecked(a as usize) }; + let rb = unsafe { rows.row_unchecked(b as usize) }; + ra.data_from(byte_pos).cmp(rb.data_from(byte_pos)) + }); + } else { + dst.copy_from_slice(src); + dst.sort_unstable_by(|&a, &b| { + // SAFETY: indices contains a permutation of 0..rows.num_rows() + let ra = unsafe { rows.row_unchecked(a as usize) }; + let rb = unsafe { rows.row_unchecked(b as usize) }; + ra.data_from(byte_pos).cmp(rb.data_from(byte_pos)) + }); + } return; } @@ -115,48 +161,54 @@ fn msd_radix_sort(indices: &mut [u32], temp: &mut [u32], rows: &Rows, byte_pos: // tried but benchmarked slower — the extra write pass costs more // than the second read through row offsets already hot in cache. let mut counts = [0u32; 256]; - for &idx in &*indices { + for &idx in &*src { // SAFETY: indices contains a permutation of 0..rows.num_rows() - let row = unsafe { rows.row_unchecked(idx as usize) }; - let byte = row.data().get(byte_pos).copied().unwrap_or(0); + let byte = unsafe { row_byte(rows, idx, byte_pos) }; counts[byte as usize] += 1; } - // Skip scatter when all rows share the same byte — one bucket - // with everything in it is just wasted work - if counts.iter().filter(|&&c| c > 0).count() == 1 { - msd_radix_sort(indices, temp, rows, byte_pos + 1); - return; - } - let mut offsets = [0u32; 257]; + let mut num_buckets = 0u32; for i in 0..256 { + num_buckets += (counts[i] > 0) as u32; offsets[i + 1] = offsets[i] + counts[i]; } - // Out-of-place scatter avoids the complexity of in-place partitioning - // across 256 buckets - let temp = &mut temp[..n]; + // No scatter happened — data is still in src, roles unchanged. + if num_buckets == 1 { + msd_radix_sort(src, dst, rows, byte_pos + 1, result_in_src); + return; + } + + // Scatter src → dst let mut write_pos = offsets; - for &idx in &*indices { + for &idx in &*src { // SAFETY: indices contains a permutation of 0..rows.num_rows() - let row = unsafe { rows.row_unchecked(idx as usize) }; - let byte = row.data().get(byte_pos).copied().unwrap_or(0) as usize; - temp[write_pos[byte] as usize] = idx; + let byte = unsafe { row_byte(rows, idx, byte_pos) } as usize; + dst[write_pos[byte] as usize] = idx; write_pos[byte] += 1; } - indices.copy_from_slice(temp); + // Recurse with roles swapped: after scatter the data lives in dst, + // so dst becomes the next level's src. Flipping result_in_src + // ensures each level's output lands where the caller above expects. for bucket in 0..256 { let start = offsets[bucket] as usize; let end = offsets[bucket + 1] as usize; - if end - start > 1 { + let len = end - start; + if len > 1 { msd_radix_sort( - &mut indices[start..end], - &mut temp[start..end], + &mut dst[start..end], + &mut src[start..end], rows, byte_pos + 1, + !result_in_src, ); + } else if len == 1 && result_in_src { + // Single-element bucket doesn't recurse. After scatter + // the element is in dst; copy it back if the caller + // expects the result in src. + src[start] = dst[start]; } } } From 89ceaf1ad140f90623fc530ab3f7154c4723574b Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 9 Apr 2026 20:33:38 -0400 Subject: [PATCH 07/11] Address PR feedback. --- arrow-row/src/radix.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/arrow-row/src/radix.rs b/arrow-row/src/radix.rs index 78401b45bc6d..ca76219b173f 100644 --- a/arrow-row/src/radix.rs +++ b/arrow-row/src/radix.rs @@ -93,8 +93,7 @@ pub fn radix_sort_to_indices(rows: &Rows) -> Vec { let n = rows.num_rows(); let mut indices: Vec = (0..n as u32).collect(); let mut temp = vec![0u32; n]; - let mut bytes = vec![0u8; n]; - msd_radix_sort(&mut indices, &mut temp, &mut bytes, rows, 0, true); + msd_radix_sort(&mut indices, &mut temp, rows, 0, true); indices } From 15351fe0908e428f036b2b6958d1dbfd2dcd271f Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 9 Apr 2026 21:13:01 -0400 Subject: [PATCH 08/11] Byte buffer extraction. --- arrow-row/src/radix.rs | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/arrow-row/src/radix.rs b/arrow-row/src/radix.rs index ca76219b173f..e4908389b3d4 100644 --- a/arrow-row/src/radix.rs +++ b/arrow-row/src/radix.rs @@ -93,7 +93,8 @@ pub fn radix_sort_to_indices(rows: &Rows) -> Vec { let n = rows.num_rows(); let mut indices: Vec = (0..n as u32).collect(); let mut temp = vec![0u32; n]; - msd_radix_sort(&mut indices, &mut temp, rows, 0, true); + let mut bytes = vec![0u8; n]; + msd_radix_sort(&mut indices, &mut temp, &mut bytes, rows, 0, true); indices } @@ -121,6 +122,7 @@ unsafe fn row_byte(rows: &Rows, idx: u32, byte_pos: usize) -> u8 { fn msd_radix_sort( src: &mut [u32], dst: &mut [u32], + bytes: &mut [u8], rows: &Rows, byte_pos: usize, result_in_src: bool, @@ -155,15 +157,16 @@ fn msd_radix_sort( return; } - // Both the histogram and scatter loops read each row's byte via - // row_unchecked. Pre-extracting bytes into a contiguous buffer was - // tried but benchmarked slower — the extra write pass costs more - // than the second read through row offsets already hot in cache. + // Extract bytes and build histogram in one pass. The bytes buffer + // is reused across levels so the scatter loop can read from a flat + // array instead of chasing pointers through Rows a second time. + let bytes = &mut bytes[..n]; let mut counts = [0u32; 256]; - for &idx in &*src { - // SAFETY: indices contains a permutation of 0..rows.num_rows() - let byte = unsafe { row_byte(rows, idx, byte_pos) }; - counts[byte as usize] += 1; + for (i, &idx) in src.iter().enumerate() { + // SAFETY: src contains valid row indices + let b = unsafe { row_byte(rows, idx, byte_pos) }; + bytes[i] = b; + counts[b as usize] += 1; } let mut offsets = [0u32; 257]; @@ -175,17 +178,16 @@ fn msd_radix_sort( // No scatter happened — data is still in src, roles unchanged. if num_buckets == 1 { - msd_radix_sort(src, dst, rows, byte_pos + 1, result_in_src); + msd_radix_sort(src, dst, bytes, rows, byte_pos + 1, result_in_src); return; } - // Scatter src → dst + // Scatter src → dst using the pre-extracted bytes let mut write_pos = offsets; - for &idx in &*src { - // SAFETY: indices contains a permutation of 0..rows.num_rows() - let byte = unsafe { row_byte(rows, idx, byte_pos) } as usize; - dst[write_pos[byte] as usize] = idx; - write_pos[byte] += 1; + for (i, &idx) in src.iter().enumerate() { + let b = bytes[i] as usize; + dst[write_pos[b] as usize] = idx; + write_pos[b] += 1; } // Recurse with roles swapped: after scatter the data lives in dst, @@ -199,6 +201,7 @@ fn msd_radix_sort( msd_radix_sort( &mut dst[start..end], &mut src[start..end], + &mut bytes[start..end], rows, byte_pos + 1, !result_in_src, From 4c4498f7408dec5ae1a50bb5904022791f91047e Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 10 Apr 2026 07:44:17 -0400 Subject: [PATCH 09/11] Update defaults based on sensitivity analysis. --- arrow-row/src/radix.rs | 64 +++++++++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 14 deletions(-) diff --git a/arrow-row/src/radix.rs b/arrow-row/src/radix.rs index e4908389b3d4..aa8c15fee636 100644 --- a/arrow-row/src/radix.rs +++ b/arrow-row/src/radix.rs @@ -50,15 +50,19 @@ use crate::Rows; -/// When a bucket has this few elements, the fixed per-level cost of radix -/// sort (256-bucket histogram + scatter) exceeds the O(n log n) cost of -/// comparison sort with small n and warm cache lines. -const FALLBACK_THRESHOLD: usize = 64; - -/// Beyond this depth, comparison sort on the full row handles the -/// remaining discrimination. 8 bytes covers the discriminating prefix -/// of most key layouts; deeper recursion hits diminishing returns as -/// buckets become sparse and the per-level overhead dominates. +/// Buckets smaller than this fall back to comparison sort. A lower +/// threshold favors the comparison path, which avoids the ping-pong +/// buffer overhead and per-level indirection cost of radix passes on +/// small buckets where O(n log n) comparison sort is already cheap. +const FALLBACK_THRESHOLD: usize = 32; + +/// Maximum number of radix passes before falling back to comparison +/// sort. Each pass chases pointers through the Rows offset/buffer +/// indirection, so deeper passes hit diminishing returns as buckets +/// shrink and the per-level overhead dominates. 8 bytes covers the +/// discriminating prefix of most key layouts including skewed or +/// narrow-range distributions; remaining ties are resolved by +/// comparison sort on the suffix. const MAX_DEPTH: usize = 8; /// Sort row indices using MSD radix sort on row-encoded keys. @@ -90,16 +94,46 @@ const MAX_DEPTH: usize = 8; /// [`take`]: https://docs.rs/arrow/latest/arrow/compute/fn.take.html /// [`lexsort_to_indices`]: https://docs.rs/arrow-ord/latest/arrow_ord/sort/fn.lexsort_to_indices.html pub fn radix_sort_to_indices(rows: &Rows) -> Vec { + radix_sort_to_indices_with(rows, MAX_DEPTH, FALLBACK_THRESHOLD) +} + +/// Like [`radix_sort_to_indices`] but with tunable parameters for +/// benchmarking. `max_depth` controls how many radix passes to run +/// before falling back to comparison sort. `fallback_threshold` is +/// the bucket size below which we switch to comparison sort. +pub fn radix_sort_to_indices_with( + rows: &Rows, + max_depth: usize, + fallback_threshold: usize, +) -> Vec { let n = rows.num_rows(); let mut indices: Vec = (0..n as u32).collect(); let mut temp = vec![0u32; n]; let mut bytes = vec![0u8; n]; - msd_radix_sort(&mut indices, &mut temp, &mut bytes, rows, 0, true); + let config = RadixSortConfig { + max_depth, + fallback_threshold, + }; + msd_radix_sort(&mut indices, &mut temp, &mut bytes, rows, 0, true, &config); indices } -/// Returns the byte at `byte_pos` for the row at `idx`, or 0 if the row -/// is shorter. +/// Tunable parameters for MSD radix sort. +struct RadixSortConfig { + /// Maximum number of radix passes before falling back to comparison sort. + max_depth: usize, + /// Buckets smaller than this fall back to comparison sort. + fallback_threshold: usize, +} + +impl Default for RadixSortConfig { + fn default() -> Self { + Self { + max_depth: MAX_DEPTH, + fallback_threshold: FALLBACK_THRESHOLD, + } + } +} /// /// # Safety /// `idx` must be a valid row index in `rows`. @@ -126,10 +160,11 @@ fn msd_radix_sort( rows: &Rows, byte_pos: usize, result_in_src: bool, + config: &RadixSortConfig, ) { let n = src.len(); - if n <= FALLBACK_THRESHOLD || byte_pos >= MAX_DEPTH { + if n <= config.fallback_threshold || byte_pos >= config.max_depth { // Compare only from byte_pos onward — earlier bytes are identical // within this bucket, having already been discriminated by radix // passes above us. Safe slice via get() is needed because rows of @@ -178,7 +213,7 @@ fn msd_radix_sort( // No scatter happened — data is still in src, roles unchanged. if num_buckets == 1 { - msd_radix_sort(src, dst, bytes, rows, byte_pos + 1, result_in_src); + msd_radix_sort(src, dst, bytes, rows, byte_pos + 1, result_in_src, config); return; } @@ -205,6 +240,7 @@ fn msd_radix_sort( rows, byte_pos + 1, !result_in_src, + config, ); } else if len == 1 && result_in_src { // Single-element bucket doesn't recurse. After scatter From c9792cc20555935f1862455f5d5980a66e066497 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 14 Apr 2026 09:43:57 -0400 Subject: [PATCH 10/11] Add more benchmark cases. --- arrow/benches/lexsort.rs | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/arrow/benches/lexsort.rs b/arrow/benches/lexsort.rs index b7adc59fa6b1..50c00684f9de 100644 --- a/arrow/benches/lexsort.rs +++ b/arrow/benches/lexsort.rs @@ -22,7 +22,7 @@ use arrow::util::bench_util::{ create_dict_from_values, create_primitive_array, create_string_array_with_len, }; use arrow::util::data_gen::create_random_array; -use arrow_array::types::Int32Type; +use arrow_array::types::{Float64Type, Int32Type, Int64Type}; use arrow_array::{Array, ArrayRef, UInt32Array}; use arrow_schema::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; @@ -32,6 +32,10 @@ use std::{hint, sync::Arc}; enum Column { RequiredI32, OptionalI32, + RequiredI64, + OptionalI64, + RequiredF64, + OptionalF64, Required16CharString, Optional16CharString, Optional50CharString, @@ -47,6 +51,10 @@ impl std::fmt::Debug for Column { let s = match self { Column::RequiredI32 => "i32", Column::OptionalI32 => "i32_opt", + Column::RequiredI64 => "i64", + Column::OptionalI64 => "i64_opt", + Column::RequiredF64 => "f64", + Column::OptionalF64 => "f64_opt", Column::Required16CharString => "str(16)", Column::Optional16CharString => "str_opt(16)", Column::Optional50CharString => "str_opt(50)", @@ -65,6 +73,10 @@ impl Column { match self { Column::RequiredI32 => Arc::new(create_primitive_array::(size, 0.)), Column::OptionalI32 => Arc::new(create_primitive_array::(size, 0.2)), + Column::RequiredI64 => Arc::new(create_primitive_array::(size, 0.)), + Column::OptionalI64 => Arc::new(create_primitive_array::(size, 0.2)), + Column::RequiredF64 => Arc::new(create_primitive_array::(size, 0.)), + Column::OptionalF64 => Arc::new(create_primitive_array::(size, 0.2)), Column::Required16CharString => { Arc::new(create_string_array_with_len::(size, 0., 16)) } @@ -166,6 +178,17 @@ fn do_bench(c: &mut Criterion, columns: &[Column], len: usize) { fn add_benchmark(c: &mut Criterion) { let cases: &[&[Column]] = &[ + // Single-column primitives + &[Column::RequiredI32], + &[Column::OptionalI32], + &[Column::RequiredI64], + &[Column::OptionalI64], + &[Column::RequiredF64], + &[Column::OptionalF64], + // Single-column strings + &[Column::Required16CharString], + &[Column::Optional16CharString], + // Multi-column &[Column::RequiredI32, Column::OptionalI32], &[Column::RequiredI32, Column::Optional16CharString], &[Column::RequiredI32, Column::Required16CharString], From d8e323c821c9aa4b49e116214aedf03f6a128523 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 14 Apr 2026 09:49:20 -0400 Subject: [PATCH 11/11] Update docs based on DataFusion TPC-H observations. --- arrow-row/src/radix.rs | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/arrow-row/src/radix.rs b/arrow-row/src/radix.rs index aa8c15fee636..7262ea814169 100644 --- a/arrow-row/src/radix.rs +++ b/arrow-row/src/radix.rs @@ -24,23 +24,32 @@ //! //! # When to use this //! -//! Radix sort on row-encoded keys is the fastest sort strategy for most -//! multi-column sorts, including: -//! - **Primitive columns** (integers, floats) -//! - **String columns**, especially multiple string columns +//! Radix sort is the fastest strategy when sorting by **two or more columns**, +//! especially as N grows. It benefits from: +//! - **Multi-column schemas** where comparison sort must traverse columns +//! per comparison, while radix sort pays a fixed cost per byte position +//! - **String columns**, where the row encoding produces compact, +//! high-entropy byte sequences that radix passes discriminate quickly //! - **Mixed column types** (primitives, strings, dicts, lists) //! -//! The advantage over [`lexsort_to_indices`] grows with N and with the -//! number of columns. -//! //! # When NOT to use this //! //! Prefer [`lexsort_to_indices`] when: +//! - **Sorting by a single column.** The row encoding overhead (allocation, +//! encoding, indirection through `Rows`) outweighs the radix advantage. +//! Single-column sorts are faster with direct comparison sort on the +//! columnar array, which avoids encoding entirely. //! - **All sort columns are low-cardinality dictionaries** with no //! high-cardinality column to break ties. The row encoding for //! dictionary values produces long shared prefixes, and radix sort //! gains little from its first few byte passes before falling back //! to comparison sort. +//! - **Columns with low-entropy leading bytes**, such as `Decimal128` or +//! `Decimal256`. These types are encoded as 16- or 32-byte big-endian +//! integers, but real-world values occupy a tiny fraction of the range. +//! The leading bytes are nearly identical across rows (e.g., `0x80` for +//! small positive values), so radix passes burn through the max depth +//! without discriminating rows, then fall back to comparison sort. //! - **A leading primitive column discriminates most rows and a trailing //! column is expensive to encode** (e.g., lists). [`lexsort_to_indices`] //! avoids encoding the trailing column for rows already resolved by