diff --git a/java/src/test/java/org/lance/operation/OverwriteTest.java b/java/src/test/java/org/lance/operation/OverwriteTest.java index f5a90de5b49..5ecda106fe4 100644 --- a/java/src/test/java/org/lance/operation/OverwriteTest.java +++ b/java/src/test/java/org/lance/operation/OverwriteTest.java @@ -29,6 +29,8 @@ import java.util.Collections; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; public class OverwriteTest extends OperationTestBase { @@ -64,10 +66,10 @@ void testOverwrite(@TempDir Path tempDir) throws Exception { } } - // Commit fragment again + // Try to commit from stale version (v1) - should fail with retryable error rowCount = 40; fragmentMeta = testDataset.createNewFragment(rowCount); - transaction = + Transaction staleTxn = dataset .newTransactionBuilder() .operation( @@ -78,9 +80,27 @@ void testOverwrite(@TempDir Path tempDir) throws Exception { .build()) .transactionProperties(Collections.singletonMap("key", "value")) .build(); - assertEquals( - "value", transaction.transactionProperties().map(m -> m.get("key")).orElse(null)); - try (Dataset dataset = transaction.commit()) { + assertEquals("value", staleTxn.transactionProperties().map(m -> m.get("key")).orElse(null)); + + RuntimeException ex = assertThrows(RuntimeException.class, () -> staleTxn.commit().close()); + assertTrue( + ex.getMessage().contains("Retryable commit conflict"), + "Expected retryable commit conflict error, got: " + ex.getMessage()); + + // Checkout latest and retry - should succeed + dataset.checkoutLatest(); + Transaction retryTxn = + dataset + .newTransactionBuilder() + .operation( + Overwrite.builder() + .fragments(Collections.singletonList(fragmentMeta)) + .schema(testDataset.getSchema()) + .configUpsertValues(Collections.singletonMap("config_key", "config_value")) + .build()) + .transactionProperties(Collections.singletonMap("key", "value")) + .build(); + try (Dataset dataset = retryTxn.commit()) { assertEquals(3, dataset.version()); assertEquals(3, dataset.latestVersion()); assertEquals(rowCount, dataset.countRows()); @@ -91,7 +111,7 @@ void testOverwrite(@TempDir Path tempDir) throws Exception { Schema schemaRes = scanner.schema(); assertEquals(testDataset.getSchema(), schemaRes); } - assertEquals(transaction, dataset.readTransaction().orElse(null)); + assertEquals(retryTxn, dataset.readTransaction().orElse(null)); } } } diff --git a/rust/lance/src/io/commit.rs b/rust/lance/src/io/commit.rs index 922060b2f73..42ed37a627a 100644 --- a/rust/lance/src/io/commit.rs +++ b/rust/lance/src/io/commit.rs @@ -1300,74 +1300,58 @@ mod tests { #[tokio::test] async fn test_concurrent_writes() { - for write_mode in [WriteMode::Append, WriteMode::Overwrite] { - // Create an empty table - let test_dir = TempStrDir::default(); - let test_uri = test_dir.as_str(); - - let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( - "i", - DataType::Int32, - false, - )])); - - let dataset = Dataset::write( - RecordBatchIterator::new(vec![].into_iter().map(Ok), schema.clone()), - test_uri, - None, - ) - .await - .unwrap(); - - // Make some sample data - let batch = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - ) - .unwrap(); + // Test concurrent appends - all should succeed + let test_dir = TempStrDir::default(); + let test_uri = test_dir.as_str(); - // Write data concurrently in 5 tasks - let futures: Vec<_> = (0..5) - .map(|_| { - let batch = batch.clone(); - let schema = schema.clone(); - let uri = test_uri.to_string(); - tokio::spawn(async move { - let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); - Dataset::write( - reader, - &uri, - Some(WriteParams { - mode: write_mode, - ..Default::default() - }), - ) - .await - }) - }) - .collect(); - let results = join_all(futures).await; + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "i", + DataType::Int32, + false, + )])); - // Assert all succeeded - for result in results { - assert!(matches!(result, Ok(Ok(_))), "{:?}", result); - } + let dataset = Dataset::write( + RecordBatchIterator::new(vec![].into_iter().map(Ok), schema.clone()), + test_uri, + None, + ) + .await + .unwrap(); - // Assert final fragments and versions expected - let dataset = dataset.checkout_version(6).await.unwrap(); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); - match write_mode { - WriteMode::Append => { - assert_eq!(dataset.get_fragments().len(), 5); - } - WriteMode::Overwrite => { - assert_eq!(dataset.get_fragments().len(), 1); - } - _ => unreachable!(), - } + let futures: Vec<_> = (0..5) + .map(|_| { + let batch = batch.clone(); + let schema = schema.clone(); + let uri = test_uri.to_string(); + tokio::spawn(async move { + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); + Dataset::write( + reader, + &uri, + Some(WriteParams { + mode: WriteMode::Append, + ..Default::default() + }), + ) + .await + }) + }) + .collect(); + let results = join_all(futures).await; - dataset.validate().await.unwrap() + for result in results { + assert!(matches!(result, Ok(Ok(_))), "{:?}", result); } + + let dataset = dataset.checkout_version(6).await.unwrap(); + assert_eq!(dataset.get_fragments().len(), 5); + dataset.validate().await.unwrap() } #[tokio::test] diff --git a/rust/lance/src/io/commit/conflict_resolver.rs b/rust/lance/src/io/commit/conflict_resolver.rs index 703afbb17e6..fc248998643 100644 --- a/rust/lance/src/io/commit/conflict_resolver.rs +++ b/rust/lance/src/io/commit/conflict_resolver.rs @@ -889,8 +889,24 @@ impl<'a> TransactionRebase<'a> { other_version: u64, ) -> Result<()> { match &other_transaction.operation { - // Overwrite only conflicts with another operation modifying the same update config - Operation::Overwrite { .. } | Operation::UpdateConfig { .. } => { + Operation::Overwrite { .. } => { + if self + .transaction + .operation + .upsert_key_conflict(&other_transaction.operation) + { + Err(self.incompatible_conflict_err( + other_transaction, + other_version, + location!(), + )) + } else { + // Concurrent overwrites are retryable so user can decide + // if their overwrite should still proceed + Err(self.retryable_conflict_err(other_transaction, other_version, location!())) + } + } + Operation::UpdateConfig { .. } => { if self .transaction .operation @@ -1796,6 +1812,7 @@ mod tests { use super::*; use crate::dataset::transaction::{DataReplacementGroup, RewriteGroup}; + use crate::dataset::write::WriteMode; use crate::session::caches::DeletionFileKey; use crate::{ dataset::{CommitBuilder, InsertBuilder, WriteParams}, @@ -2412,9 +2429,19 @@ mod tests { config_upsert_values: None, initial_bases: None, }, - // No conflicts: overwrite can always happen since it doesn't - // depend on previous state of the table. - [Compatible; 9], + // Concurrent overwrites are retryable so user can decide + // if their overwrite should still proceed. + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + Retryable, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + Compatible, // update config + ], ), ( Operation::CreateIndex { @@ -3547,4 +3574,56 @@ mod tests { assert_eq!(rebase.conflicting_mem_wal_merged_gens[0].region_id, region); assert_eq!(rebase.conflicting_mem_wal_merged_gens[0].generation, 10); } + + #[tokio::test] + async fn test_concurrent_overwrites_retryable() { + let dataset = test_dataset(5, 1).await; + let dataset_v1_reader1 = Arc::new(dataset.checkout_version(1).await.unwrap()); + let dataset_v1_reader2 = Arc::new(dataset.checkout_version(1).await.unwrap()); + + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, true), + ])), + vec![ + Arc::new(Int32Array::from_iter_values(10..15)), + Arc::new(Int32Array::from_iter_values(std::iter::repeat_n(1, 5))), + ], + ) + .unwrap(); + + // First overwrite succeeds + let txn1 = InsertBuilder::new(dataset_v1_reader1.clone()) + .with_params(&WriteParams { + mode: WriteMode::Overwrite, + ..Default::default() + }) + .execute_uncommitted(vec![data.clone()]) + .await + .unwrap(); + let dataset_v2 = CommitBuilder::new(dataset_v1_reader1) + .execute(txn1) + .await + .unwrap(); + assert_eq!(dataset_v2.manifest.version, 2); + + // Second overwrite should fail with retryable conflict + let txn2 = InsertBuilder::new(dataset_v1_reader2.clone()) + .with_params(&WriteParams { + mode: WriteMode::Overwrite, + ..Default::default() + }) + .execute_uncommitted(vec![data]) + .await + .unwrap(); + let result = CommitBuilder::new(dataset_v1_reader2).execute(txn2).await; + assert!( + matches!(result, Err(Error::RetryableCommitConflict { .. })), + "Expected RetryableCommitConflict but got: {:?}", + result + ); + + assert_eq!(dataset_v2.count_rows(None).await.unwrap(), 5); + } }