diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index ec0ac823f4d..d3fd1bc0a8c 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -8873,12 +8873,7 @@ mod tests { } fn make_tx(read_version: u64) -> Transaction { - Transaction::new( - read_version, - Operation::Append { fragments: vec![] }, - None, - None, - ) + Transaction::new(read_version, Operation::Append { fragments: vec![] }, None) } async fn delete_external_tx_file(ds: &Dataset) { @@ -8939,7 +8934,6 @@ mod tests { ds.load_indices().await.unwrap().as_ref().clone(), &tx_file, &ManifestWriteConfig::default(), - None, ) .unwrap(); let location = write_manifest_file( diff --git a/rust/lance/src/io/commit/conflict_resolver.rs b/rust/lance/src/io/commit/conflict_resolver.rs index 5e98b634f4e..0968f24db72 100644 --- a/rust/lance/src/io/commit/conflict_resolver.rs +++ b/rust/lance/src/io/commit/conflict_resolver.rs @@ -685,8 +685,24 @@ impl<'a> TransactionRebase<'a> { Ok(()) } } - Operation::DataReplacement { .. } | Operation::Merge { .. } => { - // TODO(rmeng): check that the fragments being replaced are not part of the groups + Operation::DataReplacement { replacements } => { + // These conflict if the rewrite touches any of the fragments being replaced. + for replacement in replacements { + for group in groups { + for old_fragment in &group.old_fragments { + if replacement.0 == old_fragment.id { + return Err(self.retryable_conflict_err( + other_transaction, + other_version, + location!(), + )); + } + } + } + } + Ok(()) + } + Operation::Merge { .. } => { Err(self.retryable_conflict_err(other_transaction, other_version, location!())) } Operation::CreateIndex { @@ -884,21 +900,46 @@ impl<'a> TransactionRebase<'a> { } Ok(()) } - Operation::Rewrite { .. } => { - // TODO(rmeng): check that the fragments being replaced are not part of the groups - Err(self.incompatible_conflict_err( - other_transaction, - other_version, - location!(), - )) + Operation::Rewrite { groups, .. } => { + // These conflict if the rewrite touches any of the fragments being replaced. + for replacement in replacements { + for group in groups { + for old_fragment in &group.old_fragments { + if replacement.0 == old_fragment.id { + return Err(self.retryable_conflict_err( + other_transaction, + other_version, + location!(), + )); + } + } + } + } + + Ok(()) } - Operation::DataReplacement { .. } => { - // TODO(rmeng): check cell conflicts - Err(self.incompatible_conflict_err( - other_transaction, - other_version, - location!(), - )) + Operation::DataReplacement { + replacements: other_replacements, + } => { + // These conflict if there is overlap in fragment id && fields. + for replacement in replacements { + for other_replacement in other_replacements { + if replacement.0 != other_replacement.0 { + continue; + } + + for field in &replacement.1.fields { + if other_replacement.1.fields.contains(field) { + return Err(self.retryable_conflict_err( + other_transaction, + other_version, + location!(), + )); + } + } + } + } + Ok(()) } Operation::Overwrite { .. } | Operation::Restore { .. } @@ -1665,12 +1706,13 @@ mod tests { use lance_table::io::deletion::{deletion_file_path, read_deletion_file}; use super::*; - use crate::dataset::transaction::RewriteGroup; + use crate::dataset::transaction::{DataReplacementGroup, RewriteGroup}; use crate::session::caches::DeletionFileKey; use crate::{ dataset::{CommitBuilder, InsertBuilder, WriteParams}, io, }; + use lance_table::format::DataFile; async fn test_dataset(num_rows: usize, num_fragments: usize) -> (Dataset, Arc) { let io_tracker = Arc::new(IOTracker::default()); @@ -2994,4 +3036,136 @@ mod tests { } } } + + #[tokio::test] + async fn test_conflicts_data_replacement() { + use io::commit::conflict_resolver::tests::{modified_fragment_ids, ConflictResult::*}; + + let fragment0 = Fragment::new(0); + let fragment1 = Fragment::new(1); + + let data_file_frag0_fields01 = + DataFile::new_legacy_from_fields("path0_01", vec![0, 1], None); + let data_file_frag0_fields23 = + DataFile::new_legacy_from_fields("path0_23", vec![2, 3], None); + let data_file_frag1_fields01 = + DataFile::new_legacy_from_fields("path1_01", vec![0, 1], None); + + let cases = vec![ + ( + "Different fragments", + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(0, data_file_frag0_fields01.clone())], + }, + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(1, data_file_frag1_fields01)], + }, + Compatible, + ), + ( + "Same fragment, different fields", + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(0, data_file_frag0_fields01.clone())], + }, + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(0, data_file_frag0_fields23)], + }, + Compatible, + ), + ( + "Same fragment, same fields", + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(0, data_file_frag0_fields01.clone())], + }, + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(0, data_file_frag0_fields01.clone())], + }, + Retryable, + ), + ( + "Same fragment, overlapping fields", + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(0, data_file_frag0_fields01.clone())], + }, + Operation::DataReplacement { + replacements: vec![DataReplacementGroup( + 0, + DataFile::new_legacy_from_fields("path0_12", vec![1, 2], None), + )], + }, + Retryable, + ), + ( + "DataReplacement vs Rewrite on same fragment", + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(0, data_file_frag0_fields01.clone())], + }, + Operation::Rewrite { + groups: vec![RewriteGroup { + old_fragments: vec![fragment0.clone()], + new_fragments: vec![fragment1.clone()], + }], + rewritten_indices: vec![], + frag_reuse_index: None, + }, + Retryable, + ), + ( + "DataReplacement vs Rewrite on different fragment", + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(0, data_file_frag0_fields01)], + }, + Operation::Rewrite { + groups: vec![RewriteGroup { + old_fragments: vec![fragment1], + new_fragments: vec![fragment0], + }], + rewritten_indices: vec![], + frag_reuse_index: None, + }, + Compatible, + ), + ]; + + for (description, op1, op2, expected) in cases { + let txn1 = Transaction::new(0, op1.clone(), None); + let txn2 = Transaction::new(0, op2.clone(), None); + + let mut rebase = TransactionRebase { + transaction: txn1, + initial_fragments: HashMap::new(), + modified_fragment_ids: modified_fragment_ids(&op1).collect::>(), + affected_rows: None, + conflicting_frag_reuse_indices: Vec::new(), + }; + + let result = rebase.check_txn(&txn2, 1); + match expected { + Compatible => { + assert!( + result.is_ok(), + "{}: expected Compatible but got {:?}", + description, + result + ); + } + NotCompatible => { + assert!( + matches!(result, Err(Error::CommitConflict { .. })), + "{}: expected NotCompatible but got {:?}", + description, + result + ); + } + Retryable => { + assert!( + matches!(result, Err(Error::RetryableCommitConflict { .. })), + "{}: expected Retryable but got {:?}", + description, + result + ); + } + } + } + } }