Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion arrow-ord/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,39 @@ fn partition_validity(array: &dyn Array) -> (Vec<u32>, Vec<u32>) {
}
}

/// Reports whether `arrow_ord::rank` is able to rank arrays of `data_type`.
///
/// Returns `true` for the string and binary types `Utf8`, `LargeUtf8`,
/// `Binary` and `LargeBinary`, and for every type for which
/// [`DataType::is_primitive`] holds; returns `false` otherwise.
fn can_rank(data_type: &DataType) -> bool {
    match data_type {
        DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => true,
        // Everything else is rankable only when it is a primitive type.
        other => other.is_primitive(),
    }
}

/// Reports whether `sort_to_indices` is able to sort arrays of `data_type`.
///
/// A type is accepted when it is:
/// - a primitive type (per [`DataType::is_primitive`]),
/// - `Boolean`, a string/binary type (including the view variants), or
///   `FixedSizeBinary`,
/// - a `List`, `LargeList` or `FixedSizeList` whose element type satisfies
///   `can_rank`,
/// - a `Dictionary` whose value type satisfies `can_rank`, or
/// - a `RunEndEncoded` whose values type is itself accepted by this function.
///
/// Any other type yields `false`.
fn can_sort_to_indices(data_type: &DataType) -> bool {
    if data_type.is_primitive() {
        return true;
    }

    match data_type {
        DataType::Boolean
        | DataType::Utf8
        | DataType::LargeUtf8
        | DataType::Utf8View
        | DataType::Binary
        | DataType::LargeBinary
        | DataType::BinaryView
        | DataType::FixedSizeBinary(_) => true,
        // List-like types are decided by whether their elements can be ranked.
        DataType::List(field) | DataType::LargeList(field) | DataType::FixedSizeList(field, _) => {
            can_rank(field.data_type())
        }
        // Dictionaries are decided by whether their values can be ranked.
        DataType::Dictionary(_, values) => can_rank(values.as_ref()),
        // Run-end encoded arrays recurse on the type of the values field.
        DataType::RunEndEncoded(_, values_field) => can_sort_to_indices(values_field.data_type()),
        _ => false,
    }
}

/// Sort elements from `ArrayRef` into an unsigned integer (`UInt32Array`) of indices.
/// Floats are sorted using IEEE 754 totalOrder. `limit` is an option for [partial_sort].
pub fn sort_to_indices(
Expand Down Expand Up @@ -678,7 +711,7 @@ pub fn lexsort_to_indices(
"Sort requires at least one column".to_string(),
));
}
if columns.len() == 1 {
if columns.len() == 1 && can_sort_to_indices(columns[0].values.data_type()) {
// fallback to non-lexical sort
let column = &columns[0];
return sort_to_indices(&column.values, column.options, limit);
Expand Down Expand Up @@ -762,6 +795,7 @@ mod tests {
FixedSizeListBuilder, Int64Builder, ListBuilder, PrimitiveRunBuilder,
};
use arrow_buffer::{i256, NullBuffer};
use arrow_schema::Field;
use half::f16;
use rand::rngs::StdRng;
use rand::{Rng, RngCore, SeedableRng};
Expand Down Expand Up @@ -4203,4 +4237,47 @@ mod tests {
let sort_indices = sort_to_indices(&a, None, None).unwrap();
assert_eq!(sort_indices.values(), &[1, 2, 0]);
}

#[test]
fn sort_struct_fallback_to_lexsort() {
    // Builds a struct array with two non-nullable fields: `b` (Float32) and `c` (Int32).
    let build_struct = |b_values: Vec<f32>, c_values: Vec<i32>| {
        StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Float32, false)),
                Arc::new(Float32Array::from(b_values)) as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                Arc::new(Int32Array::from(c_values)) as ArrayRef,
            ),
        ])
    };

    let input = build_struct(vec![1.0, -0.1, 3.5, 1.0], vec![42, 28, 19, 31]);

    // A struct column is rejected by the single-column sort path ...
    assert!(!can_sort_to_indices(input.data_type()));
    let direct_sort_error = sort_to_indices(&input, None, None).err().unwrap();
    assert!(direct_sort_error
        .to_string()
        .contains("Sort not supported for data type"));

    // ... but a lexsort over that one struct column must still succeed.
    let columns = vec![SortColumn {
        values: Arc::new(input) as ArrayRef,
        options: None,
    }];
    let sorted = lexsort(&columns, None).unwrap();

    // Rows are ordered by `b` first; the tie on `b == 1.0` is broken by `c`.
    let expected: ArrayRef = Arc::new(build_struct(
        vec![-0.1, 1.0, 1.0, 3.5],
        vec![28, 31, 42, 19],
    ));

    assert_eq!(&sorted[0], &expected);
}
}