Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 46 additions & 23 deletions datafusion/core/src/physical_plan/sorts/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,52 +426,75 @@ impl Iterator for SortedIterator {
// Combine adjacent indexes from the same batch to make a slice,
// for more efficient `extend` later.
let mut last_batch_idx = 0;
let mut start_row_idx = 0;
let mut len = 0;
let mut indices_in_batch = vec![];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why using a Vec here is better. Aren't the indices a contiguous range?

I wonder what about using Option<std::ops::Range> instead https://doc.rust-lang.org/std/ops/struct.Range.html ?

Maybe as a follow on PR


let mut slices = vec![];
for i in 0..current_size {
let p = self.pos + i;
let c_index = self.indices.value(p) as usize;
let ci = self.composite[c_index];

if len == 0 {
if indices_in_batch.is_empty() {
last_batch_idx = ci.batch_idx;
start_row_idx = ci.row_idx;
len = 1;
indices_in_batch.push(ci.row_idx);
} else if ci.batch_idx == last_batch_idx {
len += 1;
// since we have pre-sort each of the incoming batches,
// so if we witnessed a wrong order of indexes from the same batch,
// it must be of the same key with the row pointed by start_row_index.
start_row_idx = min(start_row_idx, ci.row_idx);
indices_in_batch.push(ci.row_idx);
} else {
slices.push(CompositeSlice {
batch_idx: last_batch_idx,
start_row_idx,
len,
});
group_indices(last_batch_idx, &mut indices_in_batch, &mut slices);
last_batch_idx = ci.batch_idx;
start_row_idx = ci.row_idx;
len = 1;
Copy link
Contributor

@alamb alamb Apr 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this the bug -- that len should be reset to 0 rather than 1?

Copy link
Member Author

@yjshen yjshen Apr 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, as described in the PR description, the bug comes from non continuous indexes, which is introduced by unstable lexsort. So it's possible we will see several disjoint ranges comes from one batch. The gap between the ranges are of same key but moved by unstable sort

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see 👍 -- that is a tricky one . So sad

indices_in_batch.push(ci.row_idx);
}
}

assert!(
len > 0,
!indices_in_batch.is_empty(),
"There should have at least one record in a sort output slice."
);
slices.push(CompositeSlice {
batch_idx: last_batch_idx,
start_row_idx,
len,
});
group_indices(last_batch_idx, &mut indices_in_batch, &mut slices);

self.pos += current_size;
Some(slices)
}
}

/// Group continuous indices into a slice for better `extend` performance
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really understand how positions can be non contiguous (doesn't this function get called each time batch_idx changes to a new batch)?

fn group_indices(
batch_idx: u32,
positions: &mut Vec<u32>,
output: &mut Vec<CompositeSlice>,
) {
positions.sort_unstable();
let mut last_pos = 0;
let mut run_length = 0;
for pos in positions.iter() {
if run_length == 0 {
last_pos = *pos;
run_length = 1;
} else if *pos == last_pos + 1 {
run_length += 1;
last_pos = *pos;
} else {
output.push(CompositeSlice {
batch_idx,
start_row_idx: last_pos + 1 - run_length,
len: run_length as usize,
});
last_pos = *pos;
run_length = 1;
}
}
assert!(
run_length > 0,
"There should have at least one record in a sort output slice."
);
output.push(CompositeSlice {
batch_idx,
start_row_idx: last_pos + 1 - run_length,
len: run_length as usize,
});
positions.clear()
}

/// Stream of sorted record batches
struct SortedSizedRecordBatchStream {
schema: SchemaRef,
Expand Down
Binary file not shown.
23 changes: 23 additions & 0 deletions datafusion/core/tests/sql/order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use super::*;
use fuzz_utils::{batches_to_vec, partitions_to_sorted_vec};

#[tokio::test]
async fn test_sort_unprojected_col() -> Result<()> {
Expand Down Expand Up @@ -198,3 +199,25 @@ async fn sort_empty() -> Result<()> {
assert_eq!(results.len(), 0);
Ok(())
}

#[tokio::test]
async fn sort_with_lots_of_repetition_values() -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️ love the tests

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW I ran the test locally without the changes in this PR to confirm coverage:

cargo test -p datafusion --test sql_integration -- sort_with_lots_of_repetition_values

They failed with:


---- sql::order::sort_with_lots_of_repetition_values stdout ----
thread 'sql::order::sort_with_lots_of_repetition_values' panicked at 'assertion failed: `(left == right)`
  left: `Some(2451809)`,
 right: `Some(2451816)`', datafusion/core/tests/sql/order.rs:220:9
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

👍

let ctx = SessionContext::new();
let filename = "tests/parquet/repeat_much.snappy.parquet";

ctx.register_parquet("rep", filename, ParquetReadOptions::default())
.await?;
let sql = "select a from rep order by a";
let actual = execute_to_batches(&ctx, sql).await;
let actual = batches_to_vec(&actual);

let sql1 = "select a from rep";
let expected = execute_to_batches(&ctx, sql1).await;
let expected = partitions_to_sorted_vec(&[expected]);

assert_eq!(actual.len(), expected.len());
for i in 0..actual.len() {
assert_eq!(actual[i], expected[i]);
}
Ok(())
}