diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs
index 12df92f534f..efbcdff92c8 100644
--- a/rust/lance/src/dataset/write/merge_insert.rs
+++ b/rust/lance/src/dataset/write/merge_insert.rs
@@ -44,9 +44,10 @@ use crate::{
 };
 use arrow_array::{
     cast::AsArray, types::UInt64Type, BooleanArray, RecordBatch, RecordBatchIterator, StructArray,
-    UInt64Array,
+    UInt32Array, UInt64Array,
 };
 use arrow_schema::{DataType, Field, Schema};
+use arrow_select::take::take_record_batch;
 use datafusion::common::NullEquality;
 use datafusion::error::DataFusionError;
 use datafusion::{
@@ -287,6 +288,19 @@ pub enum WhenNotMatched {
     DoNothing,
 }
 
+/// Describes how to handle duplicate source rows that match the same target row.
+///
+/// If the source contains duplicates and `FirstSeen` behavior doesn't match your needs,
+/// deduplicate or reorder the source so the desired row comes first before merging.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
+pub enum SourceDedupeBehavior {
+    /// Fail the operation if duplicates are found (default)
+    #[default]
+    Fail,
+    /// Keep the first seen value and skip subsequent duplicates
+    FirstSeen,
+}
+
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
 struct MergeInsertParams {
     // The column(s) to join on
@@ -310,6 +324,8 @@ struct MergeInsertParams {
     // Controls whether to use indices for the merge operation. Default is true.
     // Setting to false forces a full table scan even if an index exists.
     use_index: bool,
+    // Controls how to handle duplicate source rows that match the same target row.
+    source_dedupe_behavior: SourceDedupeBehavior,
 }
 
 /// A MergeInsertJob inserts new rows, deletes old rows, and updates existing rows all as
@@ -413,6 +429,7 @@ impl MergeInsertBuilder {
                 mem_wal_to_merge: None,
                 skip_auto_cleanup: false,
                 use_index: true,
+                source_dedupe_behavior: SourceDedupeBehavior::Fail,
             },
         })
     }
@@ -484,6 +501,18 @@ impl MergeInsertBuilder {
         self
     }
 
+    /// Specify how to handle duplicate source rows that match the same target row.
+    ///
+    /// Default is `Fail`, which errors on duplicates.
+    /// Use `FirstSeen` to keep the first encountered row and skip duplicates.
+    ///
+    /// If the source contains duplicates and `FirstSeen` behavior doesn't match your needs,
+    /// deduplicate or reorder the source so the desired row comes first before merging.
+    pub fn source_dedupe_behavior(&mut self, behavior: SourceDedupeBehavior) -> &mut Self {
+        self.params.source_dedupe_behavior = behavior;
+        self
+    }
+
     /// Indicate that this merge-insert uses data in a flushed MemTable.
     /// Once write is completed, the corresponding MemTable should also be marked as merged.
    pub async fn mark_mem_wal_as_merged(
@@ -1822,6 +1851,8 @@ pub struct MergeStats {
     pub bytes_written: u64,
     /// Number of data files written. This currently only includes data files.
     pub num_files_written: u64,
+    /// Number of duplicate source rows skipped (when SourceDedupeBehavior::FirstSeen)
+    pub num_skipped_duplicates: u64,
 }
 
 pub struct UncommittedMergeInsert {
@@ -2079,44 +2110,69 @@ impl Merger {
                 let row_ids = matched.column(row_id_col).as_primitive::<UInt64Type>();
                 let mut processed_row_ids = self.processed_row_ids.lock().unwrap();
+                let mut keep_indices: Vec<u32> = Vec::with_capacity(matched.num_rows());
                 for (row_idx, &row_id) in row_ids.values().iter().enumerate() {
-                    if !processed_row_ids.insert(row_id) {
-                        return Err(create_duplicate_row_error(
-                            &matched,
-                            row_idx,
-                            &self.params.on,
-                        ));
+                    if processed_row_ids.insert(row_id) {
+                        keep_indices.push(row_idx as u32);
+                    } else {
+                        match self.params.source_dedupe_behavior {
+                            SourceDedupeBehavior::Fail => {
+                                return Err(create_duplicate_row_error(
+                                    &matched,
+                                    row_idx,
+                                    &self.params.on,
+                                ));
+                            }
+                            SourceDedupeBehavior::FirstSeen => {
+                                // Skip this duplicate row (don't add to keep_indices)
+                            }
+                        }
                     }
                 }
                 drop(processed_row_ids);
 
-                deleted_row_ids.extend(row_ids.values());
-                if self.enable_stable_row_ids {
-                    self.updating_row_ids
-                        .lock()
-                        .unwrap()
-                        .capture(row_ids.values())?;
+                // Filter out duplicate rows if any were skipped
+                let num_skipped = matched.num_rows() - keep_indices.len();
+                if num_skipped > 0 {
+                    merge_statistics.num_skipped_duplicates += num_skipped as u64;
+                    merge_statistics.num_updated_rows -= num_skipped as u64;
+
+                    let indices = UInt32Array::from(keep_indices);
+                    matched = take_record_batch(&matched, &indices)?;
                 }
 
-                let projection = if let Some(row_addr_col) = row_addr_col {
-                    let mut cols = Vec::from_iter(left_cols.iter().cloned());
-                    cols.push(row_addr_col);
-                    cols
-                } else {
-                    #[allow(clippy::redundant_clone)]
-                    left_cols.clone()
-                };
-                let matched = matched.project(&projection)?;
-                // The payload columns of an outer join are always nullable. We need to restore
-                // non-nullable to columns that were originally non-nullable. This should be safe
-                // since the not_matched rows should all be valid on the right_cols
-                //
-                // Sadly we can't use with_schema because it doesn't let you toggle nullability
-                let matched = RecordBatch::try_new(
-                    self.output_schema.clone(),
-                    Vec::from_iter(matched.columns().iter().cloned()),
-                )?;
-                batches.push(Ok(matched));
+                // Only process and write if there are remaining rows after filtering duplicates
+                if matched.num_rows() > 0 {
+                    // Get row_ids again after filtering (if any duplicates were removed)
+                    let row_ids = matched.column(row_id_col).as_primitive::<UInt64Type>();
+                    deleted_row_ids.extend(row_ids.values());
+                    if self.enable_stable_row_ids {
+                        self.updating_row_ids
+                            .lock()
+                            .unwrap()
+                            .capture(row_ids.values())?;
+                    }
+
+                    let projection = if let Some(row_addr_col) = row_addr_col {
+                        let mut cols = Vec::from_iter(left_cols.iter().cloned());
+                        cols.push(row_addr_col);
+                        cols
+                    } else {
+                        #[allow(clippy::redundant_clone)]
+                        left_cols.clone()
+                    };
+                    let matched = matched.project(&projection)?;
+                    // The payload columns of an outer join are always nullable. We need to restore
+                    // non-nullable to columns that were originally non-nullable. This should be safe
+                    // since the not_matched rows should all be valid on the right_cols
+                    //
+                    // Sadly we can't use with_schema because it doesn't let you toggle nullability
+                    let matched = RecordBatch::try_new(
+                        self.output_schema.clone(),
+                        Vec::from_iter(matched.columns().iter().cloned()),
+                    )?;
+                    batches.push(Ok(matched));
+                }
             }
         }
         if self.params.insert_not_matched {
@@ -4818,7 +4874,7 @@ MergeInsert: on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_n
         );
 
         // Also validate the full string structure with pattern matching
-        let expected_pattern = "[...MergeInsert: elapsed=..., on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep, metrics=...bytes_written=...num_deleted_rows=0, num_files_written=...num_inserted_rows=1, num_updated_rows=1]
+        let expected_pattern = "[...MergeInsert: elapsed=..., on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep, metrics=...bytes_written=...num_deleted_rows=0, num_files_written=...num_inserted_rows=1, num_skipped_duplicates=0, num_updated_rows=1]
 ...
   StreamingTableExec: partition_sizes=1, projection=[id, name], metrics=[]...]";
         assert_string_matches(&analysis, expected_pattern).unwrap();
@@ -5012,6 +5068,150 @@ MergeInsert: on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_n
         );
     }
 
+    #[tokio::test]
+    #[rstest::rstest]
+    async fn test_source_dedupe_behavior_first_seen(
+        #[values(false, true)] is_full_schema: bool,
+        #[values(true, false)] enable_stable_row_ids: bool,
+        #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)]
+        data_storage_version: LanceFileVersion,
+    ) {
+        let test_uri = format!(
+            "memory://test_dedupe_first_seen_{}_{}.lance",
+            is_full_schema, enable_stable_row_ids
+        );
+
+        // Create initial dataset with keys 1, 2, 3, 4
+        let dataset = lance_datagen::gen_batch()
+            .col("key", array::step_custom::<UInt32Type>(1, 1))
+            .col("value", array::step_custom::<UInt32Type>(10, 10))
+            .into_dataset_with_params(
+                &test_uri,
+                FragmentCount(1),
+                FragmentRowCount(4),
+                Some(WriteParams {
+                    max_rows_per_file: 4,
+                    enable_stable_row_ids,
+                    data_storage_version: Some(data_storage_version),
+                    ..Default::default()
+                }),
+            )
+            .await
+            .unwrap();
+
+        // Initial data: key=1,value=10; key=2,value=20; key=3,value=30; key=4,value=40
+        let initial_data: Vec<(u32, u32)> = dataset
+            .scan()
+            .try_into_batch()
+            .await
+            .unwrap()
+            .columns()
+            .iter()
+            .map(|c| c.as_primitive::<UInt32Type>().values().to_vec())
+            .collect::<Vec<_>>()
+            .into_iter()
+            .fold(Vec::new(), |mut acc, vals| {
+                if acc.is_empty() {
+                    acc = vals.into_iter().map(|v| (v, 0)).collect();
+                } else {
+                    for (i, v) in vals.into_iter().enumerate() {
+                        acc[i].1 = v;
+                    }
+                }
+                acc
+            });
+        assert_eq!(
+            initial_data,
+            vec![(1, 10), (2, 20), (3, 30), (4, 40)],
+            "Initial data should be correct"
+        );
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("key", DataType::UInt32, is_full_schema),
+            Field::new("value", DataType::UInt32, is_full_schema),
+        ]));
+
+        // Source data with duplicates:
+        // - key=2 appears 3 times with values 100, 200, 300 (first seen: 100)
+        // - key=3 appears 2 times with values 400, 500 (first seen: 400)
+        // - key=5 is a new insert (value=600)
+        // Total duplicates: 3 (2 extra for key=2, 1 extra for key=3)
+        let source_batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(UInt32Array::from(vec![2, 2, 2, 3, 3, 5])),
+                Arc::new(UInt32Array::from(vec![100, 200, 300, 400, 500, 600])),
+            ],
+        )
+        .unwrap();
+
+        let job = MergeInsertBuilder::try_new(Arc::new(dataset), vec!["key".to_string()])
+            .unwrap()
+            .when_matched(WhenMatched::UpdateAll)
+            .when_not_matched(WhenNotMatched::InsertAll)
+            .source_dedupe_behavior(SourceDedupeBehavior::FirstSeen)
+            .try_build()
+            .unwrap();
+
+        let reader = Box::new(RecordBatchIterator::new([Ok(source_batch)], schema.clone()));
+        let stream = reader_to_stream(reader);
+
+        let (dataset, stats) = job.execute(stream).await.unwrap();
+
+        // Verify stats
+        assert_eq!(
+            stats.num_skipped_duplicates, 3,
+            "Should have skipped 3 duplicate rows (2 extra for key=2, 1 extra for key=3)"
+        );
+        assert_eq!(
+            stats.num_updated_rows, 2,
+            "Should have updated 2 rows (key=2 and key=3)"
+        );
+        assert_eq!(
+            stats.num_inserted_rows, 1,
+            "Should have inserted 1 row (key=5)"
+        );
+
+        // Verify the actual data - first seen values should be kept
+        let result_batch = dataset.scan().try_into_batch().await.unwrap();
+        let keys = result_batch.column(0).as_primitive::<UInt32Type>();
+        let values = result_batch.column(1).as_primitive::<UInt32Type>();
+
+        let result_data: std::collections::HashMap<u32, u32> = keys
+            .values()
+            .iter()
+            .zip(values.values().iter())
+            .map(|(&k, &v)| (k, v))
+            .collect();
+
+        assert_eq!(result_data.len(), 5, "Should have 5 rows total");
+        assert_eq!(
+            result_data.get(&1),
+            Some(&10),
+            "key=1 should be unchanged (original value)"
+        );
+        assert_eq!(
+            result_data.get(&2),
+            Some(&100),
+            "key=2 should have first seen value (100, not 200 or 300)"
+        );
+        assert_eq!(
+            result_data.get(&3),
+            Some(&400),
+            "key=3 should have first seen value (400, not 500)"
+        );
+        assert_eq!(
+            result_data.get(&4),
+            Some(&40),
+            "key=4 should be unchanged (original value)"
+        );
+        assert_eq!(
+            result_data.get(&5),
+            Some(&600),
+            "key=5 should be inserted with value 600"
+        );
+    }
+
     #[tokio::test]
     async fn test_merge_insert_use_index() {
         let data = lance_datagen::gen_batch()
diff --git a/rust/lance/src/dataset/write/merge_insert/exec.rs b/rust/lance/src/dataset/write/merge_insert/exec.rs
index de39d0a1610..473051da181 100644
--- a/rust/lance/src/dataset/write/merge_insert/exec.rs
+++ b/rust/lance/src/dataset/write/merge_insert/exec.rs
@@ -24,6 +24,7 @@ pub(super) struct MergeInsertMetrics {
     pub num_deleted_rows: Count,
     pub bytes_written: Count,
     pub num_files_written: Count,
+    pub num_skipped_duplicates: Count,
 }
 
 impl From<&MergeInsertMetrics> for MergeStats {
@@ -34,6 +35,7 @@ impl From<&MergeInsertMetrics> for MergeStats {
             num_updated_rows: value.num_updated_rows.value() as u64,
             bytes_written: value.bytes_written.value() as u64,
             num_files_written: value.num_files_written.value() as u64,
+            num_skipped_duplicates: value.num_skipped_duplicates.value() as u64,
             num_attempts: 1,
         }
     }
@@ -46,12 +48,15 @@ impl MergeInsertMetrics {
         let num_deleted_rows = MetricBuilder::new(metrics).counter("num_deleted_rows", partition);
         let bytes_written = MetricBuilder::new(metrics).counter("bytes_written", partition);
         let num_files_written = MetricBuilder::new(metrics).counter("num_files_written", partition);
+        let num_skipped_duplicates =
+            MetricBuilder::new(metrics).counter("num_skipped_duplicates", partition);
         Self {
             num_inserted_rows,
             num_updated_rows,
             num_deleted_rows,
             bytes_written,
             num_files_written,
+            num_skipped_duplicates,
         }
     }
 }
diff --git a/rust/lance/src/dataset/write/merge_insert/exec/delete.rs b/rust/lance/src/dataset/write/merge_insert/exec/delete.rs
index 1503b5b21c4..69fc71246ea 100644
--- a/rust/lance/src/dataset/write/merge_insert/exec/delete.rs
+++ b/rust/lance/src/dataset/write/merge_insert/exec/delete.rs
@@ -296,6 +296,7 @@ impl ExecutionPlan for DeleteOnlyMergeInsertExec {
                 bytes_written: 0,
                 num_files_written: 0,
                 num_attempts: 1,
+                num_skipped_duplicates: 0,
             };
 
             if let Ok(mut transaction_guard) = transaction_holder.lock() {
diff --git a/rust/lance/src/dataset/write/merge_insert/exec/write.rs b/rust/lance/src/dataset/write/merge_insert/exec/write.rs
index 4f24c94ca41..3d5527f47bd 100644
--- a/rust/lance/src/dataset/write/merge_insert/exec/write.rs
+++ b/rust/lance/src/dataset/write/merge_insert/exec/write.rs
@@ -30,7 +30,7 @@ use crate::dataset::write::merge_insert::inserted_rows::{
     extract_key_value_from_batch, KeyExistenceFilter, KeyExistenceFilterBuilder,
 };
 use crate::dataset::write::merge_insert::{
-    create_duplicate_row_error, format_key_values_on_columns,
+    create_duplicate_row_error, format_key_values_on_columns, SourceDedupeBehavior,
 };
 use crate::{
     dataset::{
@@ -64,6 +64,8 @@ struct MergeState {
     processed_row_ids: HashSet<u64>,
     /// The "on" column names for merge operation
     on_columns: Vec<String>,
+    /// How to handle duplicate source rows
+    source_dedupe_behavior: SourceDedupeBehavior,
 }
 
 impl MergeState {
@@ -72,6 +74,7 @@
         stable_row_ids: bool,
         on_columns: Vec<String>,
         field_ids: Vec<i32>,
+        source_dedupe_behavior: SourceDedupeBehavior,
     ) -> Self {
         Self {
            delete_row_addrs: RoaringTreemap::new(),
@@ -81,6 +84,7 @@
             stable_row_ids,
             processed_row_ids: HashSet::new(),
             on_columns,
+            source_dedupe_behavior,
         }
     }
 
@@ -111,7 +115,19 @@
 
         // Check for duplicate _rowid in the current merge operation
         if !self.processed_row_ids.insert(row_id) {
-            return Err(create_duplicate_row_error(batch, row_idx, &self.on_columns));
+            match self.source_dedupe_behavior {
+                SourceDedupeBehavior::Fail => {
+                    return Err(create_duplicate_row_error(
+                        batch,
+                        row_idx,
+                        &self.on_columns,
+                    ));
+                }
+                SourceDedupeBehavior::FirstSeen => {
+                    self.metrics.num_skipped_duplicates.add(1);
+                    return Ok(None); // Skip this duplicate row
+                }
+            }
         }
 
         self.delete_row_addrs.insert(row_addr);
@@ -829,6 +845,7 @@ impl ExecutionPlan for FullSchemaMergeInsertExec {
             self.dataset.manifest.uses_stable_row_ids(),
             self.params.on.clone(),
             field_ids,
+            self.params.source_dedupe_behavior,
         )));
         let write_data_stream =
             self.create_filtered_write_stream(input_stream, merge_state.clone())?;
@@ -970,9 +987,15 @@ mod tests {
     use arrow_array::UInt64Array;
 
     #[test]
-    fn test_merge_state_duplicate_rowid_detection() {
+    fn test_merge_state_duplicate_rowid_detection_fail() {
         let metrics = MergeInsertMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
-        let mut merge_state = MergeState::new(metrics, false, Vec::new(), Vec::new());
+        let mut merge_state = MergeState::new(
+            metrics,
+            false,
+            Vec::new(),
+            Vec::new(),
+            SourceDedupeBehavior::Fail,
+        );
 
         let row_addr_array = UInt64Array::from(vec![1000, 2000, 3000]);
         let row_id_array = UInt64Array::from(vec![100, 100, 300]); // Duplicate row_id 100
@@ -1018,4 +1041,66 @@
             "Third call with different _rowid should succeed"
         );
     }
+
+    #[test]
+    fn test_merge_state_duplicate_rowid_first_seen() {
+        let metrics = MergeInsertMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
+        let mut merge_state = MergeState::new(
+            metrics,
+            false,
+            Vec::new(),
+            Vec::new(),
+            SourceDedupeBehavior::FirstSeen,
+        );
+
+        let row_addr_array = UInt64Array::from(vec![1000, 2000, 3000]);
+        let row_id_array = UInt64Array::from(vec![100, 100, 300]); // Duplicate row_id 100
+
+        let result1 = merge_state.process_row_action(
+            Action::UpdateAll,
+            0,
+            &row_addr_array,
+            &row_id_array,
+            &RecordBatch::new_empty(Arc::new(arrow_schema::Schema::empty())),
+        );
+        assert!(result1.is_ok(), "First call should succeed");
+        assert_eq!(result1.unwrap(), Some(0), "First row should be kept");
+
+        let result2 = merge_state.process_row_action(
+            Action::UpdateAll,
+            1,
+            &row_addr_array,
+            &row_id_array,
+            &RecordBatch::new_empty(Arc::new(arrow_schema::Schema::empty())),
+        );
+        assert!(
+            result2.is_ok(),
+            "Second call with duplicate _rowid should succeed with FirstSeen"
+        );
+        assert_eq!(
+            result2.unwrap(),
+            None,
+            "Duplicate row should be skipped (return None)"
+        );
+
+        // Verify the metric was incremented
+        assert_eq!(
+            merge_state.metrics.num_skipped_duplicates.value(),
+            1,
+            "num_skipped_duplicates should be 1"
+        );
+
+        let result3 = merge_state.process_row_action(
+            Action::UpdateAll,
+            2,
+            &row_addr_array,
+            &row_id_array,
+            &RecordBatch::new_empty(Arc::new(arrow_schema::Schema::empty())),
+        );
+        assert!(
+            result3.is_ok(),
+            "Third call with different _rowid should succeed"
+        );
+        assert_eq!(result3.unwrap(), Some(2), "Third row should be kept");
+    }
 }
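
Usage note (not part of the diff): a minimal sketch of how a caller opts into the new behavior, based on the builder API added above. The import paths, the `SendableRecordBatchStream` source type, and the surrounding function are assumptions for illustration; the calls themselves (`source_dedupe_behavior`, `try_build`, `execute`, and `stats.num_skipped_duplicates`) are the ones introduced or exercised in this diff.

```rust
use std::sync::Arc;

use datafusion::physical_plan::SendableRecordBatchStream;
use lance::dataset::Dataset;
// Module path assumed from this diff's file layout; adjust to the crate's actual re-export.
use lance::dataset::write::merge_insert::{
    MergeInsertBuilder, SourceDedupeBehavior, WhenMatched, WhenNotMatched,
};

/// Upsert `source` into `dataset` on the "key" column, keeping the first
/// source row seen per key instead of failing on duplicates.
async fn upsert_first_seen(
    dataset: Arc<Dataset>,
    source: SendableRecordBatchStream,
) -> lance::Result<u64> {
    let job = MergeInsertBuilder::try_new(dataset, vec!["key".to_string()])?
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::InsertAll)
        // Default is SourceDedupeBehavior::Fail; opt into first-writer-wins.
        .source_dedupe_behavior(SourceDedupeBehavior::FirstSeen)
        .try_build()?;

    let (_updated_dataset, stats) = job.execute(source).await?;
    // New in this diff: count of duplicate source rows that were dropped.
    Ok(stats.num_skipped_duplicates)
}
```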
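For reviewers skimming the semantics: both code paths (the `keep_indices` filter in `merge_insert.rs` and the early `Ok(None)` return in `write.rs`) reduce to plain first-writer-wins set logic — `HashSet::insert` returns `false` for an already-seen row id, and the row is skipped and counted rather than turned into an error. A self-contained sketch of the invariant the new tests check; `first_seen_indices` is a hypothetical helper, not a function in this PR:

```rust
use std::collections::HashSet;

/// Return the indices of first occurrences (in encounter order) plus the
/// number of skipped duplicates, mirroring how the merger builds
/// `keep_indices` before calling `take_record_batch`.
fn first_seen_indices(row_ids: &[u64]) -> (Vec<u32>, u64) {
    let mut seen = HashSet::new();
    let mut keep = Vec::with_capacity(row_ids.len());
    let mut skipped = 0u64;
    for (idx, id) in row_ids.iter().enumerate() {
        if seen.insert(*id) {
            keep.push(idx as u32); // first time this row id is seen: keep it
        } else {
            skipped += 1; // duplicate: the first occurrence already won
        }
    }
    (keep, skipped)
}

fn main() {
    // Same shape as the unit tests above: row id 100 appears twice.
    let (keep, skipped) = first_seen_indices(&[100, 100, 300]);
    assert_eq!(keep, vec![0, 2]);
    assert_eq!(skipped, 1);
}
```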