266 changes: 233 additions & 33 deletions rust/lance/src/dataset/write/merge_insert.rs
@@ -44,9 +44,10 @@ use crate::{
};
use arrow_array::{
cast::AsArray, types::UInt64Type, BooleanArray, RecordBatch, RecordBatchIterator, StructArray,
UInt64Array,
UInt32Array, UInt64Array,
};
use arrow_schema::{DataType, Field, Schema};
use arrow_select::take::take_record_batch;
use datafusion::common::NullEquality;
use datafusion::error::DataFusionError;
use datafusion::{
@@ -287,6 +288,19 @@ pub enum WhenNotMatched {
DoNothing,
}

/// Describes how to handle duplicate source rows that match the same target row.
///
/// If the source contains duplicates and `FirstSeen` behavior doesn't match your needs,
/// sort the source data so the row you want to keep comes first before passing it to
/// the merge insert operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub enum SourceDedupeBehavior {
/// Fail the operation if duplicates are found (default)
#[default]
Fail,
/// Keep the first seen row and skip subsequent duplicates
FirstSeen,
}
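
// Reviewer note, not part of this diff: a minimal sketch of the "sort the
// source first" advice above. If the last version of each key should win
// rather than the first, reorder the source so the preferred row is seen
// first; `FirstSeen` then keeps it. The column names ("key", "version") and
// the use of arrow-ord here are illustrative assumptions, not part of this API.
fn prefer_newest(
    batch: &RecordBatch,
) -> Result<RecordBatch, arrow_schema::ArrowError> {
    use arrow_ord::sort::{lexsort_to_indices, SortColumn, SortOptions};
    use arrow_select::take::take_record_batch;
    let sort_columns = vec![
        SortColumn {
            // Group duplicate rows together by key (ascending).
            values: batch.column_by_name("key").unwrap().clone(),
            options: None,
        },
        SortColumn {
            // Within a key, put the highest version first so it is "first seen".
            values: batch.column_by_name("version").unwrap().clone(),
            options: Some(SortOptions {
                descending: true,
                nulls_first: false,
            }),
        },
    ];
    let indices = lexsort_to_indices(&sort_columns, None)?;
    take_record_batch(batch, &indices)
}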

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
struct MergeInsertParams {
// The column(s) to join on
@@ -310,6 +324,8 @@ struct MergeInsertParams {
// Controls whether to use indices for the merge operation. Default is true.
// Setting to false forces a full table scan even if an index exists.
use_index: bool,
// Controls how to handle duplicate source rows that match the same target row.
source_dedupe_behavior: SourceDedupeBehavior,
}

/// A MergeInsertJob inserts new rows, deletes old rows, and updates existing rows all as
Expand Down Expand Up @@ -413,6 +429,7 @@ impl MergeInsertBuilder {
mem_wal_to_merge: None,
skip_auto_cleanup: false,
use_index: true,
source_dedupe_behavior: SourceDedupeBehavior::Fail,
},
})
}
Expand Down Expand Up @@ -484,6 +501,18 @@ impl MergeInsertBuilder {
self
}

/// Specify how to handle duplicate source rows that match the same target row.
///
/// The default is `Fail`, which errors on duplicates.
/// Use `FirstSeen` to keep the first encountered row and skip subsequent duplicates.
///
/// If the source contains duplicates and `FirstSeen` behavior doesn't match your needs,
/// sort the source data so the row you want to keep comes first before passing it to
/// the merge insert operation.
pub fn source_dedupe_behavior(&mut self, behavior: SourceDedupeBehavior) -> &mut Self {
self.params.source_dedupe_behavior = behavior;
self
}

/// Indicate that this merge-insert uses data in a flushed MemTable.
/// Once the write completes, the corresponding MemTable should also be marked as merged.
pub async fn mark_mem_wal_as_merged(
Expand Down Expand Up @@ -1822,6 +1851,8 @@ pub struct MergeStats {
pub bytes_written: u64,
/// Number of data files written. This currently only includes data files.
pub num_files_written: u64,
/// Number of duplicate source rows skipped (when `SourceDedupeBehavior::FirstSeen`).
pub num_skipped_duplicates: u64,
}

pub struct UncommittedMergeInsert {
Expand Down Expand Up @@ -2079,44 +2110,69 @@ impl Merger {
let row_ids = matched.column(row_id_col).as_primitive::<UInt64Type>();

let mut processed_row_ids = self.processed_row_ids.lock().unwrap();
let mut keep_indices: Vec<u32> = Vec::with_capacity(matched.num_rows());
for (row_idx, &row_id) in row_ids.values().iter().enumerate() {
if !processed_row_ids.insert(row_id) {
return Err(create_duplicate_row_error(
&matched,
row_idx,
&self.params.on,
));
if processed_row_ids.insert(row_id) {
keep_indices.push(row_idx as u32);
} else {
match self.params.source_dedupe_behavior {
SourceDedupeBehavior::Fail => {
return Err(create_duplicate_row_error(
&matched,
row_idx,
&self.params.on,
));
}
SourceDedupeBehavior::FirstSeen => {
// Skip this duplicate row (don't add to keep_indices)
}
}
}
}
drop(processed_row_ids);

deleted_row_ids.extend(row_ids.values());
if self.enable_stable_row_ids {
self.updating_row_ids
.lock()
.unwrap()
.capture(row_ids.values())?;
// Filter out duplicate rows if any were skipped
let num_skipped = matched.num_rows() - keep_indices.len();
if num_skipped > 0 {
merge_statistics.num_skipped_duplicates += num_skipped as u64;
merge_statistics.num_updated_rows -= num_skipped as u64;

let indices = UInt32Array::from(keep_indices);
matched = take_record_batch(&matched, &indices)?;
}

let projection = if let Some(row_addr_col) = row_addr_col {
let mut cols = Vec::from_iter(left_cols.iter().cloned());
cols.push(row_addr_col);
cols
} else {
#[allow(clippy::redundant_clone)]
left_cols.clone()
};
let matched = matched.project(&projection)?;
// The payload columns of an outer join are always nullable. We need to restore
// non-nullable to columns that were originally non-nullable. This should be safe
// since the not_matched rows should all be valid on the right_cols
//
// Sadly we can't use with_schema because it doesn't let you toggle nullability
let matched = RecordBatch::try_new(
self.output_schema.clone(),
Vec::from_iter(matched.columns().iter().cloned()),
)?;
batches.push(Ok(matched));
// Only process and write if there are remaining rows after filtering duplicates
if matched.num_rows() > 0 {
// Get row_ids again after filtering (if any duplicates were removed)
let row_ids = matched.column(row_id_col).as_primitive::<UInt64Type>();
deleted_row_ids.extend(row_ids.values());
if self.enable_stable_row_ids {
self.updating_row_ids
.lock()
.unwrap()
.capture(row_ids.values())?;
}

let projection = if let Some(row_addr_col) = row_addr_col {
let mut cols = Vec::from_iter(left_cols.iter().cloned());
cols.push(row_addr_col);
cols
} else {
#[allow(clippy::redundant_clone)]
left_cols.clone()
};
let matched = matched.project(&projection)?;
// The payload columns of an outer join are always nullable. We need to restore
// non-nullability to columns that were originally non-nullable. This should be
// safe since the not_matched rows should all be valid on the right_cols.
//
// Sadly we can't use with_schema because it doesn't let you toggle nullability
let matched = RecordBatch::try_new(
self.output_schema.clone(),
Vec::from_iter(matched.columns().iter().cloned()),
)?;
batches.push(Ok(matched));
}
}
}
if self.params.insert_not_matched {
Expand Down Expand Up @@ -4818,7 +4874,7 @@ MergeInsert: on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_n
);

// Also validate the full string structure with pattern matching
let expected_pattern = "[...MergeInsert: elapsed=..., on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep, metrics=...bytes_written=...num_deleted_rows=0, num_files_written=...num_inserted_rows=1, num_updated_rows=1]
let expected_pattern = "[...MergeInsert: elapsed=..., on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep, metrics=...bytes_written=...num_deleted_rows=0, num_files_written=...num_inserted_rows=1, num_skipped_duplicates=0, num_updated_rows=1]
...
StreamingTableExec: partition_sizes=1, projection=[id, name], metrics=[]...]";
assert_string_matches(&analysis, expected_pattern).unwrap();
Expand Down Expand Up @@ -5012,6 +5068,150 @@ MergeInsert: on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_n
);
}

#[rstest::rstest]
#[tokio::test]
async fn test_source_dedupe_behavior_first_seen(
#[values(false, true)] is_full_schema: bool,
#[values(true, false)] enable_stable_row_ids: bool,
#[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)]
data_storage_version: LanceFileVersion,
) {
let test_uri = format!(
"memory://test_dedupe_first_seen_{}_{}.lance",
is_full_schema, enable_stable_row_ids
);

// Create initial dataset with keys 1, 2, 3, 4
let dataset = lance_datagen::gen_batch()
.col("key", array::step_custom::<UInt32Type>(1, 1))
.col("value", array::step_custom::<UInt32Type>(10, 10))
.into_dataset_with_params(
&test_uri,
FragmentCount(1),
FragmentRowCount(4),
Some(WriteParams {
max_rows_per_file: 4,
enable_stable_row_ids,
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await
.unwrap();

// Initial data: key=1,value=10; key=2,value=20; key=3,value=30; key=4,value=40
let initial_data: Vec<(u32, u32)> = dataset
.scan()
.try_into_batch()
.await
.unwrap()
.columns()
.iter()
.map(|c| c.as_primitive::<UInt32Type>().values().to_vec())
.collect::<Vec<_>>()
.into_iter()
.fold(Vec::new(), |mut acc, vals| {
if acc.is_empty() {
acc = vals.into_iter().map(|v| (v, 0)).collect();
} else {
for (i, v) in vals.into_iter().enumerate() {
acc[i].1 = v;
}
}
acc
});
assert_eq!(
initial_data,
vec![(1, 10), (2, 20), (3, 30), (4, 40)],
"Initial data should be correct"
);

let schema = Arc::new(Schema::new(vec![
Field::new("key", DataType::UInt32, is_full_schema),
Field::new("value", DataType::UInt32, is_full_schema),
]));

// Source data with duplicates:
// - key=2 appears 3 times with values 100, 200, 300 (first seen: 100)
// - key=3 appears 2 times with values 400, 500 (first seen: 400)
// - key=5 is a new insert (value=600)
// Total duplicates: 3 (2 extra for key=2, 1 extra for key=3)
let source_batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from(vec![2, 2, 2, 3, 3, 5])),
Arc::new(UInt32Array::from(vec![100, 200, 300, 400, 500, 600])),
],
)
.unwrap();

let job = MergeInsertBuilder::try_new(Arc::new(dataset), vec!["key".to_string()])
.unwrap()
.when_matched(WhenMatched::UpdateAll)
.when_not_matched(WhenNotMatched::InsertAll)
.source_dedupe_behavior(SourceDedupeBehavior::FirstSeen)
.try_build()
.unwrap();

let reader = Box::new(RecordBatchIterator::new([Ok(source_batch)], schema.clone()));
let stream = reader_to_stream(reader);

let (dataset, stats) = job.execute(stream).await.unwrap();

// Verify stats
assert_eq!(
stats.num_skipped_duplicates, 3,
"Should have skipped 3 duplicate rows (2 extra for key=2, 1 extra for key=3)"
);
assert_eq!(
stats.num_updated_rows, 2,
"Should have updated 2 rows (key=2 and key=3)"
);
assert_eq!(
stats.num_inserted_rows, 1,
"Should have inserted 1 row (key=5)"
);

// Verify the actual data - first seen values should be kept
let result_batch = dataset.scan().try_into_batch().await.unwrap();
let keys = result_batch.column(0).as_primitive::<UInt32Type>();
let values = result_batch.column(1).as_primitive::<UInt32Type>();

let result_data: std::collections::HashMap<u32, u32> = keys
.values()
.iter()
.zip(values.values().iter())
.map(|(&k, &v)| (k, v))
.collect();

assert_eq!(result_data.len(), 5, "Should have 5 rows total");
assert_eq!(
result_data.get(&1),
Some(&10),
"key=1 should be unchanged (original value)"
);
assert_eq!(
result_data.get(&2),
Some(&100),
"key=2 should have first seen value (100, not 200 or 300)"
);
assert_eq!(
result_data.get(&3),
Some(&400),
"key=3 should have first seen value (400, not 500)"
);
assert_eq!(
result_data.get(&4),
Some(&40),
"key=4 should be unchanged (original value)"
);
assert_eq!(
result_data.get(&5),
Some(&600),
"key=5 should be inserted with value 600"
);
}

#[tokio::test]
async fn test_merge_insert_use_index() {
let data = lance_datagen::gen_batch()
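For reference, a minimal caller-side sketch of the new option, mirroring the builder calls exercised by `test_source_dedupe_behavior_first_seen` above. The import paths, the stream type, and the surrounding async function are assumptions for illustration, not part of the diff:

use std::sync::Arc;

use crate::dataset::write::merge_insert::{
    MergeInsertBuilder, SourceDedupeBehavior, WhenMatched, WhenNotMatched,
};
use crate::dataset::Dataset;

async fn upsert_keeping_first(
    dataset: Arc<Dataset>,
    source: datafusion::physical_plan::SendableRecordBatchStream,
) -> crate::Result<()> {
    let job = MergeInsertBuilder::try_new(dataset, vec!["key".to_string()])?
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::InsertAll)
        // Opt out of the default `Fail`: extra source rows for the same key
        // are skipped and only the first one seen is applied.
        .source_dedupe_behavior(SourceDedupeBehavior::FirstSeen)
        .try_build()?;
    let (_dataset, stats) = job.execute(source).await?;
    println!("skipped {} duplicate source rows", stats.num_skipped_duplicates);
    Ok(())
}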
5 changes: 5 additions & 0 deletions rust/lance/src/dataset/write/merge_insert/exec.rs
@@ -24,6 +24,7 @@ pub(super) struct MergeInsertMetrics {
pub num_deleted_rows: Count,
pub bytes_written: Count,
pub num_files_written: Count,
pub num_skipped_duplicates: Count,
}

impl From<&MergeInsertMetrics> for MergeStats {
Expand All @@ -34,6 +35,7 @@ impl From<&MergeInsertMetrics> for MergeStats {
num_updated_rows: value.num_updated_rows.value() as u64,
bytes_written: value.bytes_written.value() as u64,
num_files_written: value.num_files_written.value() as u64,
num_skipped_duplicates: value.num_skipped_duplicates.value() as u64,
num_attempts: 1,
}
}
@@ -46,12 +48,15 @@ impl MergeInsertMetrics {
let num_deleted_rows = MetricBuilder::new(metrics).counter("num_deleted_rows", partition);
let bytes_written = MetricBuilder::new(metrics).counter("bytes_written", partition);
let num_files_written = MetricBuilder::new(metrics).counter("num_files_written", partition);
let num_skipped_duplicates =
MetricBuilder::new(metrics).counter("num_skipped_duplicates", partition);
Self {
num_inserted_rows,
num_updated_rows,
num_deleted_rows,
bytes_written,
num_files_written,
num_skipped_duplicates,
}
}
}
1 change: 1 addition & 0 deletions rust/lance/src/dataset/write/merge_insert/exec/delete.rs
@@ -296,6 +296,7 @@ impl ExecutionPlan for DeleteOnlyMergeInsertExec {
bytes_written: 0,
num_files_written: 0,
num_attempts: 1,
num_skipped_duplicates: 0,
};

if let Ok(mut transaction_guard) = transaction_holder.lock() {