diff --git a/java/lance-jni/src/transaction.rs b/java/lance-jni/src/transaction.rs index 103d3a2c2b7..4542b36de4e 100644 --- a/java/lance-jni/src/transaction.rs +++ b/java/lance-jni/src/transaction.rs @@ -464,6 +464,8 @@ fn convert_to_java_operation_inner<'local>( new_fragments, fields_modified: _, mem_wal_to_merge: _, + fields_for_preserving_frag_bitmap: _, + update_mode: _, } => { let removed_ids: Vec> = removed_fragment_ids .iter() @@ -887,6 +889,8 @@ fn convert_to_rust_operation( new_fragments, fields_modified: vec![], mem_wal_to_merge: None, + update_mode: None, + fields_for_preserving_frag_bitmap: vec![], } } "DataReplacement" => { diff --git a/protos/transaction.proto b/protos/transaction.proto index 79bdb5a2fa7..720299eb7ca 100644 --- a/protos/transaction.proto +++ b/protos/transaction.proto @@ -194,6 +194,25 @@ message Transaction { repeated uint32 fields_modified = 4; /// The MemWAL (pre-image) that should be marked as merged after this transaction MemWalIndexDetails.MemWal mem_wal_to_merge = 5; + /// The fields that used to judge whether to preserve the new frag's id into + /// the frag bitmap of the specified indices. + repeated uint32 fields_for_preserving_frag_bitmap = 6; + // The mode of update + UpdateMode update_mode = 7; + } + + // The mode of update operation + enum UpdateMode { + + /// rows are deleted in current fragments and rewritten in new fragments. + /// This is most optimal when the majority of columns are being rewritten + /// or only a few rows are being updated. + REWRITE_ROWS = 0; + + /// within each fragment, columns are fully rewritten and inserted as new data files. + /// Old versions of columns are tombstoned. This is most optimal when most rows are affected + /// but a small subset of columns are affected. + REWRITE_COLUMNS = 1; } // An operation that updates the table config. diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 2a8123120c2..04829c05dce 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -3632,12 +3632,17 @@ class Update(BaseOperation): If any fields are modified in updated_fragments, then they must be listed here so those fragments can be removed from indices covering those fields. + fields_for_preserving_frag_bitmap: list[int] + The fields that used to judge whether to preserve the new frag's id into + the frag bitmap of the specified indices. """ removed_fragment_ids: List[int] updated_fragments: List[FragmentMetadata] new_fragments: List[FragmentMetadata] fields_modified: List[int] + fields_for_preserving_frag_bitmap: List[int] + update_mode: str def __post_init__(self): LanceOperation._validate_fragments(self.updated_fragments) diff --git a/python/src/transaction.rs b/python/src/transaction.rs index 64dabce601e..4ac763be84d 100644 --- a/python/src/transaction.rs +++ b/python/src/transaction.rs @@ -1,10 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use crate::schema::LanceSchema; +use crate::utils::{class_name, export_vec, extract_vec, PyLance}; use arrow::pyarrow::PyArrowType; use arrow_schema::Schema as ArrowSchema; use lance::dataset::transaction::{ - DataReplacementGroup, Operation, RewriteGroup, RewrittenIndex, Transaction, + DataReplacementGroup, Operation, RewriteGroup, RewrittenIndex, Transaction, UpdateMode, }; use lance::datatypes::Schema; use lance_table::format::{DataFile, Fragment, Index}; @@ -17,9 +19,6 @@ use std::collections::HashMap; use std::sync::Arc; use uuid::Uuid; -use crate::schema::LanceSchema; -use crate::utils::{class_name, export_vec, extract_vec, PyLance}; - // Add Index bindings impl FromPyObject<'_> for PyLance { fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { @@ -141,6 +140,23 @@ impl<'py> IntoPyObject<'py> for PyLance<&DataReplacementGroup> { } } +#[derive(Debug, Clone)] +pub struct PyUpdateMode(pub UpdateMode); + +impl FromPyObject<'_> for PyUpdateMode { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + let mode_str: String = ob.extract()?; + match mode_str.as_str() { + "rewrite_rows" => Ok(Self(UpdateMode::RewriteRows)), + "rewrite_columns" => Ok(Self(UpdateMode::RewriteColumns)), + _ => Err(PyValueError::new_err(format!( + "Invalid UpdateMode: {}. Valid options are: rewrite_rows, rewrite_columns", + mode_str + ))), + } + } +} + impl FromPyObject<'_> for PyLance { fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { match class_name(ob)?.as_str() { @@ -182,12 +198,25 @@ impl FromPyObject<'_> for PyLance { let fields_modified = ob.getattr("fields_modified")?.extract()?; + let fields_for_preserving_frag_bitmap = ob + .getattr("fields_for_preserving_frag_bitmap")? + .extract() + .unwrap_or_default(); + + let update_mode = ob + .getattr("update_mode")? + .extract::() + .ok() + .map(|py_mode| py_mode.0); + let op = Operation::Update { removed_fragment_ids, updated_fragments, new_fragments, fields_modified, mem_wal_to_merge: None, + fields_for_preserving_frag_bitmap, + update_mode, }; Ok(Self(op)) } @@ -290,12 +319,25 @@ impl<'py> IntoPyObject<'py> for PyLance<&Operation> { updated_fragments, new_fragments, fields_modified, + fields_for_preserving_frag_bitmap, + update_mode, .. } => { let removed_fragment_ids = removed_fragment_ids.into_pyobject(py)?; let updated_fragments = export_vec(py, updated_fragments.as_slice())?; let new_fragments = export_vec(py, new_fragments.as_slice())?; let fields_modified = fields_modified.into_pyobject(py)?; + let fields_for_preserving_frag_bitmap = + fields_for_preserving_frag_bitmap.into_pyobject(py)?; + let update_mode = match update_mode { + Some(mode) => match mode { + lance::dataset::transaction::UpdateMode::RewriteRows => "rewrite_rows", + lance::dataset::transaction::UpdateMode::RewriteColumns => { + "rewrite_columns" + } + }, + None => "rewrite_rows", + }; let cls = namespace .getattr("Update") .expect("Failed to get Update class"); @@ -304,6 +346,8 @@ impl<'py> IntoPyObject<'py> for PyLance<&Operation> { updated_fragments, new_fragments, fields_modified, + fields_for_preserving_frag_bitmap, + update_mode, )) } Operation::DataReplacement { replacements } => { diff --git a/rust/lance/src/dataset/transaction.rs b/rust/lance/src/dataset/transaction.rs index 8f46c45872b..f9813586ea9 100644 --- a/rust/lance/src/dataset/transaction.rs +++ b/rust/lance/src/dataset/transaction.rs @@ -46,6 +46,7 @@ //! use super::ManifestWriteConfig; +use crate::dataset::transaction::UpdateMode::RewriteRows; use crate::index::mem_wal::update_mem_wal_index_in_indices_list; use crate::utils::temporal::timestamp_to_nanos; use deepsize::DeepSizeOf; @@ -212,6 +213,11 @@ pub enum Operation { fields_modified: Vec, /// The MemWAL (pre-image) that should be marked as merged after this transaction mem_wal_to_merge: Option, + /// The fields that used to judge whether to preserve the new frag's id into + /// the frag bitmap of the specified indices. + fields_for_preserving_frag_bitmap: Vec, + /// The mode of update + update_mode: Option, }, /// Project to a new schema. This only changes the schema, not the data. @@ -240,6 +246,19 @@ pub enum Operation { }, } +#[derive(Debug, Clone, PartialEq, DeepSizeOf)] +pub enum UpdateMode { + /// rows are deleted in current fragments and rewritten in new fragments. + /// This is most optimal when the majority of columns are being rewritten + /// or only a few rows are being updated. + RewriteRows, + + /// within each fragment, columns are fully rewritten and inserted as new data files. + /// Old versions of columns are tombstoned. This is most optimal when most rows are affected + /// but a small subset of columns are affected. + RewriteColumns, +} + impl std::fmt::Display for Operation { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -372,6 +391,8 @@ impl PartialEq for Operation { new_fragments: a_new, fields_modified: a_fields, mem_wal_to_merge: a_mem_wal_to_merge, + fields_for_preserving_frag_bitmap: a_fields_for_preserving_frag_bitmap, + update_mode: a_update_mode, }, Self::Update { removed_fragment_ids: b_removed, @@ -379,6 +400,8 @@ impl PartialEq for Operation { new_fragments: b_new, fields_modified: b_fields, mem_wal_to_merge: b_mem_wal_to_merge, + fields_for_preserving_frag_bitmap: b_fields_for_preserving_frag_bitmap, + update_mode: b_update_mode, }, ) => { compare_vec(a_removed, b_removed) @@ -386,6 +409,11 @@ impl PartialEq for Operation { && compare_vec(a_new, b_new) && compare_vec(a_fields, b_fields) && a_mem_wal_to_merge == b_mem_wal_to_merge + && compare_vec( + a_fields_for_preserving_frag_bitmap, + b_fields_for_preserving_frag_bitmap, + ) + && a_update_mode == b_update_mode } (Self::Project { schema: a }, Self::Project { schema: b }) => a == b, ( @@ -1395,6 +1423,8 @@ impl Transaction { new_fragments, fields_modified, mem_wal_to_merge, + fields_for_preserving_frag_bitmap, + update_mode, } => { final_fragments.extend(maybe_existing_fragments?.iter().filter_map(|f| { if removed_fragment_ids.contains(&f.id) { @@ -1417,6 +1447,29 @@ impl Transaction { let mut new_fragments = Self::fragments_with_ids(new_fragments.clone(), &mut fragment_id) .collect::>(); + + if config.use_stable_row_ids + && update_mode.is_some() + && *update_mode == Some(RewriteRows) + { + let pure_updated_frag_ids = + Self::collect_pure_rewrite_row_update_frags_ids(&new_fragments)?; + + // collect all the original frag ids that contains the updated rows + let original_fragment_ids: Vec = removed_fragment_ids + .iter() + .chain(updated_fragments.iter().map(|f| &f.id)) + .copied() + .collect(); + + Self::register_pure_rewrite_rows_update_frags_in_indices( + &mut final_indices, + &pure_updated_frag_ids, + &original_fragment_ids, + fields_for_preserving_frag_bitmap, + ); + } + if let Some(next_row_id) = &mut next_row_id { Self::assign_row_ids(next_row_id, new_fragments.as_mut_slice())?; } @@ -1752,6 +1805,45 @@ impl Transaction { Ok((manifest, final_indices)) } + fn register_pure_rewrite_rows_update_frags_in_indices( + indices: &mut [Index], + pure_update_frag_ids: &[u64], + original_fragment_ids: &[u64], + fields_for_preserving_frag_bitmap: &[u32], + ) { + if pure_update_frag_ids.is_empty() { + return; + } + + let value_updated_field_set = fields_for_preserving_frag_bitmap + .iter() + .collect::>(); + + for index in indices.iter_mut() { + let index_covers_modified_field = index.fields.iter().any(|field_id| { + value_updated_field_set.contains(&u32::try_from(*field_id).unwrap()) + }); + + if !index_covers_modified_field { + if let Some(fragment_bitmap) = &mut index.fragment_bitmap { + // check if all the original fragments contains the updating rows are covered + // by the index(index fragment bitmap contains these frag ids). + // if not, that means not all the updating rows are indexed, so we could not + // index them. + let index_covers_all_original_fragments = original_fragment_ids + .iter() + .all(|&fragment_id| fragment_bitmap.contains(fragment_id as u32)); + + if index_covers_all_original_fragments { + for fragment_id in pure_update_frag_ids.iter().map(|f| *f as u32) { + fragment_bitmap.insert(fragment_id); + } + } + } + } + } + } + /// If an operation modifies one or more fields in a fragment then we need to remove /// that fragment from any indices that cover one of the modified fields. fn prune_updated_fields_from_indices( @@ -2005,6 +2097,36 @@ impl Transaction { Ok(()) } + /// collect the pure(the num of row IDs are equal to the physical rows) "rewrite rows" updated fragment ids + fn collect_pure_rewrite_row_update_frags_ids(fragments: &[Fragment]) -> Result> { + let mut pure_update_frag_ids = Vec::new(); + + for fragment in fragments { + let physical_rows = fragment.physical_rows.ok_or_else(|| Error::Internal { + message: "Fragment does not have physical rows".into(), + location: location!(), + })? as u64; + + if let Some(row_id_meta) = &fragment.row_id_meta { + let existing_row_count = match row_id_meta { + RowIdMeta::Inline(data) => { + let sequence = read_row_ids(data)?; + sequence.len() as u64 + } + _ => 0, + }; + + // only filter the fragments that match: all the rows have row id, + // which means it does not contain inserted rows in this fragment + if existing_row_count == physical_rows { + pure_update_frag_ids.push(fragment.id); + } + } + } + + Ok(pure_update_frag_ids) + } + fn assign_row_ids(next_row_id: &mut u64, fragments: &mut [Fragment]) -> Result<()> { for fragment in fragments { let physical_rows = fragment.physical_rows.ok_or_else(|| Error::Internal { @@ -2239,6 +2361,8 @@ impl TryFrom for Transaction { new_fragments, fields_modified, mem_wal_to_merge, + fields_for_preserving_frag_bitmap, + update_mode, })) => Operation::Update { removed_fragment_ids, updated_fragments: updated_fragments @@ -2251,6 +2375,12 @@ impl TryFrom for Transaction { .collect::>>()?, fields_modified, mem_wal_to_merge: mem_wal_to_merge.map(|m| MemWal::try_from(m).unwrap()), + fields_for_preserving_frag_bitmap, + update_mode: match update_mode { + 0 => Some(UpdateMode::RewriteRows), + 1 => Some(UpdateMode::RewriteColumns), + _ => Some(UpdateMode::RewriteRows), + }, }, Some(pb::transaction::Operation::Project(pb::transaction::Project { schema })) => { Operation::Project { @@ -2528,6 +2658,8 @@ impl From<&Transaction> for pb::Transaction { new_fragments, fields_modified, mem_wal_to_merge, + fields_for_preserving_frag_bitmap, + update_mode, } => pb::transaction::Operation::Update(pb::transaction::Update { removed_fragment_ids: removed_fragment_ids.clone(), updated_fragments: updated_fragments @@ -2536,9 +2668,15 @@ impl From<&Transaction> for pb::Transaction { .collect(), new_fragments: new_fragments.iter().map(pb::DataFragment::from).collect(), fields_modified: fields_modified.clone(), - mem_wal_to_merge: mem_wal_to_merge + mem_wal_to_merge: mem_wal_to_merge.as_ref().map(|m| m.into()), + fields_for_preserving_frag_bitmap: fields_for_preserving_frag_bitmap.clone(), + update_mode: update_mode .as_ref() - .map(pb::mem_wal_index_details::MemWal::from), + .map(|mode| match mode { + UpdateMode::RewriteRows => 0, + UpdateMode::RewriteColumns => 1, + }) + .unwrap_or(0), }), Operation::Project { schema } => { pb::transaction::Operation::Project(pb::transaction::Project { diff --git a/rust/lance/src/dataset/write/commit.rs b/rust/lance/src/dataset/write/commit.rs index abfac37d415..9a8991a78c9 100644 --- a/rust/lance/src/dataset/write/commit.rs +++ b/rust/lance/src/dataset/write/commit.rs @@ -789,6 +789,8 @@ mod tests { removed_fragment_ids: vec![], fields_modified: vec![], mem_wal_to_merge: None, + fields_for_preserving_frag_bitmap: vec![], + update_mode: None, }, read_version: 1, blobs_op: None, diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index a11c3c8eeee..7cd8d687d1c 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -24,6 +24,7 @@ use assign_action::merge_insert_action; use super::retry::{execute_with_retry, RetryConfig, RetryExecutor}; use super::{write_fragments_internal, CommitBuilder, WriteParams}; use crate::dataset::rowids::get_row_id_index; +use crate::dataset::transaction::UpdateMode::{RewriteColumns, RewriteRows}; use crate::dataset::utils::CapturedRowIds; use crate::{ datafusion::dataframe::SessionContextExt, @@ -1420,6 +1421,8 @@ impl MergeInsertJob { new_fragments, fields_modified, mem_wal_to_merge: self.params.mem_wal_to_merge, + fields_for_preserving_frag_bitmap: vec![], // in-place update do not affect preserving frag bitmap + update_mode: Some(RewriteColumns), }; // We have rewritten the fragments, not just the deletion files, so // we can't use affected rows here. @@ -1490,6 +1493,12 @@ impl MergeInsertJob { // modify any field values. fields_modified: vec![], mem_wal_to_merge: self.params.mem_wal_to_merge, + fields_for_preserving_frag_bitmap: full_schema + .fields + .iter() + .map(|f| f.id as u32) + .collect(), + update_mode: Some(RewriteRows), }; let affected_rows = Some(RowIdTreeMap::from(removed_row_addrs)); @@ -2031,6 +2040,7 @@ impl Merger { mod tests { use super::*; use crate::dataset::scanner::ColumnOrdering; + use crate::index::vector::VectorIndexParams; use crate::{ dataset::{builder::DatasetBuilder, InsertBuilder, ReadParams, WriteMode, WriteParams}, session::Session, @@ -2039,18 +2049,23 @@ mod tests { FragmentRowCount, ThrottledStoreWrapper, }, }; + use arrow_array::types::Float32Type; use arrow_array::{ types::{Int32Type, UInt32Type}, - Int32Array, Int64Array, RecordBatchIterator, RecordBatchReader, StringArray, UInt32Array, + FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatchIterator, + RecordBatchReader, StringArray, UInt32Array, }; use arrow_select::concat::concat_batches; use datafusion::common::Column; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use futures::{future::try_join_all, FutureExt, StreamExt, TryStreamExt}; + use lance_arrow::FixedSizeListArrayExt; use lance_datafusion::{datagen::DatafusionDatagenExt, utils::reader_to_stream}; - use lance_datagen::{array, BatchCount, RowCount, Seed}; - use lance_index::{scalar::ScalarIndexParams, IndexType}; + use lance_datagen::{array, BatchCount, Dimension, RowCount, Seed}; + use lance_index::scalar::ScalarIndexParams; + use lance_index::IndexType; use lance_io::object_store::ObjectStoreParams; + use lance_linalg::distance::MetricType; use object_store::throttle::ThrottleConfig; use roaring::RoaringBitmap; use std::collections::HashMap; @@ -4520,4 +4535,259 @@ MergeInsert: on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_n .unwrap(); assert_eq!(updated_count, 3); } + + #[tokio::test] + async fn test_full_schema_upsert_fragment_bitmap() { + let schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, true), + Field::new("value", DataType::UInt32, true), + Field::new( + "vec", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4), + true, + ), + ])); + + let mut dataset = lance_datagen::gen_batch() + .col("key", array::step_custom::(1, 1)) + .col("value", array::step_custom::(10, 10)) + .col( + "vec", + array::cycle_vec( + array::cycle::(vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, + 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, + ]), + Dimension::from(4), + ), + ) + .into_ram_dataset_with_params( + FragmentCount::from(2), + FragmentRowCount::from(3), + Some(WriteParams { + max_rows_per_file: 3, + enable_stable_row_ids: true, + ..Default::default() + }), + ) + .await + .unwrap(); + + let scalar_params = ScalarIndexParams::default(); + dataset + .create_index( + &["value"], + IndexType::Scalar, + Some("value_idx".to_string()), + &scalar_params, + true, + ) + .await + .unwrap(); + + let vector_params = VectorIndexParams::ivf_flat(1, MetricType::L2); + dataset + .create_index( + &["vec"], + IndexType::Vector, + Some("vec_idx".to_string()), + &vector_params, + true, + ) + .await + .unwrap(); + + let indices = dataset.load_indices().await.unwrap(); + let value_index = indices.iter().find(|idx| idx.name == "value_idx").unwrap(); + let vec_index = indices.iter().find(|idx| idx.name == "vec_idx").unwrap(); + + assert_eq!( + value_index + .fragment_bitmap + .as_ref() + .unwrap() + .iter() + .collect::>(), + vec![0, 1] + ); + assert_eq!( + vec_index + .fragment_bitmap + .as_ref() + .unwrap() + .iter() + .collect::>(), + vec![0, 1] + ); + + // update keys: 2,5 + let upsert_keys = UInt32Array::from(vec![2, 5]); + let upsert_values = UInt32Array::from(vec![200, 500]); + let upsert_vecs = FixedSizeListArray::try_new_from_values( + Float32Array::from(vec![21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0]), + 4, + ) + .unwrap(); + + let upsert_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(upsert_keys), + Arc::new(upsert_values), + Arc::new(upsert_vecs), + ], + ) + .unwrap(); + + let upsert_stream = RecordBatchStreamAdapter::new( + schema.clone(), + futures::stream::once(async { Ok(upsert_batch) }).boxed(), + ); + + let (updated_dataset, _stats) = + MergeInsertBuilder::try_new(Arc::new(dataset), vec!["key".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::DoNothing) + .when_not_matched_by_source(WhenNotMatchedBySource::Keep) + .try_build() + .unwrap() + .execute(Box::pin(upsert_stream)) + .await + .unwrap(); + + let fragments = updated_dataset.get_fragments(); + assert_eq!(fragments.len(), 3); + } + + #[tokio::test] + async fn test_sub_schema_upsert_fragment_bitmap() { + let mut dataset = lance_datagen::gen_batch() + .col("key", array::step_custom::(1, 1)) + .col("value", array::step_custom::(10, 10)) + .col( + "vec", + array::cycle_vec( + array::cycle::(vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, + 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, + ]), + Dimension::from(4), + ), + ) + .into_ram_dataset_with_params( + FragmentCount::from(2), + FragmentRowCount::from(3), + Some(WriteParams { + max_rows_per_file: 3, + enable_stable_row_ids: true, + ..Default::default() + }), + ) + .await + .unwrap(); + + let scalar_params = ScalarIndexParams::default(); + dataset + .create_index( + &["value"], + IndexType::Scalar, + Some("value_idx".to_string()), + &scalar_params, + true, + ) + .await + .unwrap(); + + let vector_params = VectorIndexParams::ivf_flat(1, MetricType::L2); + dataset + .create_index( + &["vec"], + IndexType::Vector, + Some("vec_idx".to_string()), + &vector_params, + true, + ) + .await + .unwrap(); + + let indices = dataset.load_indices().await.unwrap(); + let value_index = indices.iter().find(|idx| idx.name == "value_idx").unwrap(); + let vec_index = indices.iter().find(|idx| idx.name == "vec_idx").unwrap(); + + assert_eq!( + value_index + .fragment_bitmap + .as_ref() + .unwrap() + .iter() + .collect::>(), + vec![0, 1] + ); + assert_eq!( + vec_index + .fragment_bitmap + .as_ref() + .unwrap() + .iter() + .collect::>(), + vec![0, 1] + ); + + let sub_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, true), + Field::new( + "vec", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4), + true, + ), + ])); + + let upsert_keys = UInt32Array::from(vec![2, 5]); + let upsert_vecs = FixedSizeListArray::try_new_from_values( + Float32Array::from(vec![21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0]), + 4, + ) + .unwrap(); + + let upsert_batch = RecordBatch::try_new( + sub_schema.clone(), + vec![Arc::new(upsert_keys), Arc::new(upsert_vecs)], + ) + .unwrap(); + + let upsert_stream = RecordBatchStreamAdapter::new( + sub_schema.clone(), + futures::stream::once(async { Ok(upsert_batch) }).boxed(), + ); + + let (updated_dataset, _stats) = + MergeInsertBuilder::try_new(Arc::new(dataset), vec!["key".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::DoNothing) + .when_not_matched_by_source(WhenNotMatchedBySource::Keep) + .try_build() + .unwrap() + .execute(Box::pin(upsert_stream)) + .await + .unwrap(); + + let fragments = updated_dataset.get_fragments(); + // in-place updates only, no new fragment should be added + assert_eq!(fragments.len(), 2); + + let updated_indices = updated_dataset.load_indices().await.unwrap(); + // all the fragments have been updated, so the index of the vector field has been deleted + assert_eq!(updated_indices.len(), 1); + let updated_value_index = updated_indices + .iter() + .find(|idx| idx.name == "value_idx") + .unwrap(); + + let value_bitmap = updated_value_index.fragment_bitmap.as_ref().unwrap(); + assert_eq!(value_bitmap.len(), 2); + assert!(value_bitmap.contains(0)); + assert!(value_bitmap.contains(1)); + } } diff --git a/rust/lance/src/dataset/write/merge_insert/exec/write.rs b/rust/lance/src/dataset/write/merge_insert/exec/write.rs index f4d6c7101d1..94a23490422 100644 --- a/rust/lance/src/dataset/write/merge_insert/exec/write.rs +++ b/rust/lance/src/dataset/write/merge_insert/exec/write.rs @@ -21,6 +21,7 @@ use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use futures::{stream, StreamExt}; use roaring::RoaringTreemap; +use crate::dataset::transaction::UpdateMode::RewriteRows; use crate::dataset::utils::CapturedRowIds; use crate::dataset::write::merge_insert::create_duplicate_row_error; use crate::{ @@ -868,6 +869,13 @@ impl ExecutionPlan for FullSchemaMergeInsertExec { new_fragments, fields_modified: vec![], // No fields are modified in schema for upsert mem_wal_to_merge, + fields_for_preserving_frag_bitmap: dataset + .schema() + .fields + .iter() + .map(|f| f.id as u32) + .collect(), + update_mode: Some(RewriteRows), }; // Step 5: Create and store the transaction diff --git a/rust/lance/src/dataset/write/update.rs b/rust/lance/src/dataset/write/update.rs index c1860e1c94c..ac1b92535f8 100644 --- a/rust/lance/src/dataset/write/update.rs +++ b/rust/lance/src/dataset/write/update.rs @@ -5,8 +5,10 @@ use std::collections::{BTreeMap, HashMap}; use std::sync::Arc; use std::time::Duration; +use super::retry::{execute_with_retry, RetryConfig, RetryExecutor}; use super::{write_fragments_internal, CommitBuilder, WriteParams}; use crate::dataset::rowids::get_row_id_index; +use crate::dataset::transaction::UpdateMode::RewriteRows; use crate::dataset::transaction::{Operation, Transaction}; use crate::dataset::utils::make_rowid_capture_stream; use crate::{io::exec::Planner, Dataset}; @@ -30,8 +32,6 @@ use lance_table::format::{Fragment, RowIdMeta}; use roaring::RoaringTreemap; use snafu::{location, ResultExt}; -use super::retry::{execute_with_retry, RetryConfig, RetryExecutor}; - /// Build an update operation. /// /// This operation is similar to SQL's UPDATE statement. It allows you to change @@ -275,7 +275,7 @@ impl UpdateJob { retry_timeout: self.retry_timeout, }; - execute_with_retry(self, dataset, config).await + Box::pin(execute_with_retry(self, dataset, config)).await } async fn execute_impl(self) -> Result { @@ -392,15 +392,27 @@ impl UpdateJob { dataset: Arc, update_data: UpdateData, ) -> Result { + let mut fields_for_preserving_frag_bitmap = Vec::new(); + for column_name in self.updates.keys() { + if let Ok(field_id) = dataset.schema().field_id(column_name) { + fields_for_preserving_frag_bitmap.push(field_id as u32); + } + } + // Commit updated and new fragments let operation = Operation::Update { removed_fragment_ids: update_data.removed_fragment_ids, updated_fragments: update_data.old_fragments, new_fragments: update_data.new_fragments, - // This job only deletes rows, it does not modify any field values. + // In "rewrite rows" mode, the rows that are updated in the fragment + // are moved(deleted and appended). + // so we do not need to handle the frag bitmap of the index about it. fields_modified: vec![], mem_wal_to_merge: None, + fields_for_preserving_frag_bitmap, + update_mode: Some(RewriteRows), }; + let transaction = Transaction::new( dataset.manifest.version, operation, @@ -509,14 +521,23 @@ mod tests { use super::*; + use crate::dataset::{WriteDestination, WriteMode}; + use crate::index::vector::VectorIndexParams; + use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; use arrow::{array::AsArray, datatypes::UInt32Type}; + use arrow_array::types::Float32Type; use arrow_array::{Int64Array, RecordBatchIterator, StringArray, UInt32Array, UInt64Array}; use arrow_schema::{Field, Schema as ArrowSchema}; use arrow_select::concat::concat_batches; use futures::{future::try_join_all, TryStreamExt}; use lance_core::ROW_ID; + use lance_datagen::{Dimension, RowCount}; use lance_file::version::LanceFileVersion; + use lance_index::scalar::ScalarIndexParams; + use lance_index::DatasetIndexExt; + use lance_index::IndexType; use lance_io::object_store::ObjectStoreParams; + use lance_linalg::distance::MetricType; use object_store::throttle::ThrottleConfig; use rstest::rstest; use tempfile::{tempdir, TempDir}; @@ -1048,4 +1069,297 @@ mod tests { } } } + + #[tokio::test] + async fn test_update_affects_index_fragment_bitmap() { + let mut dataset = lance_datagen::gen_batch() + .col( + "str", + lance_datagen::array::cycle_utf8_literals(&["a", "b", "c", "d", "e", "f"]), + ) + .col( + "vec", + lance_datagen::array::rand_vec::(Dimension::from(4)), + ) + .into_ram_dataset_with_params( + FragmentCount::from(2), + FragmentRowCount::from(3), + Some(WriteParams { + max_rows_per_file: 3, + enable_stable_row_ids: true, + ..Default::default() + }), + ) + .await + .unwrap(); + + let scalar_params = ScalarIndexParams::default(); + dataset + .create_index( + &["str"], + IndexType::Scalar, + Some("str_idx".to_string()), + &scalar_params, + true, + ) + .await + .unwrap(); + + let vector_params = VectorIndexParams::ivf_flat(1, MetricType::L2); + dataset + .create_index( + &["vec"], + IndexType::Vector, + Some("vec_idx".to_string()), + &vector_params, + true, + ) + .await + .unwrap(); + + let indices = dataset.load_indices().await.unwrap(); + let str_index = indices.iter().find(|idx| idx.name == "str_idx").unwrap(); + let vec_index = indices.iter().find(|idx| idx.name == "vec_idx").unwrap(); + + assert_eq!( + str_index + .fragment_bitmap + .as_ref() + .unwrap() + .iter() + .collect::>(), + vec![0, 1] + ); + assert_eq!( + vec_index + .fragment_bitmap + .as_ref() + .unwrap() + .iter() + .collect::>(), + vec![0, 1] + ); + + let updated_dataset = UpdateBuilder::new(Arc::new(dataset)) + .update_where("str = 'e'") + .unwrap() + .set("vec", "array[25.0, 26.0, 27.0, 28.0]") + .unwrap() + .build() + .unwrap() + .execute() + .await + .unwrap() + .new_dataset; + + let updated_indices = updated_dataset.load_indices().await.unwrap(); + let updated_str_index = updated_indices + .iter() + .find(|idx| idx.name == "str_idx") + .unwrap(); + let updated_vec_index = updated_indices + .iter() + .find(|idx| idx.name == "vec_idx") + .unwrap(); + + let str_bitmap = updated_str_index.fragment_bitmap.as_ref().unwrap(); + assert_eq!(str_bitmap.len(), 3); + assert_eq!(str_bitmap.iter().collect::>(), vec![0, 1, 2]); + + let vec_bitmap = updated_vec_index.fragment_bitmap.as_ref().unwrap(); + assert_eq!(vec_bitmap.len(), 2); + assert_eq!(vec_bitmap.iter().collect::>(), vec![0, 1]); + + let fragments = updated_dataset.get_fragments(); + assert!(fragments.len() > 2); + + let second_fragment = &fragments[1]; + assert!(second_fragment + .get_deletion_vector() + .await + .unwrap() + .is_some()); + } + + #[tokio::test] + async fn test_update_mixed_indexed_unindexed_fragments() { + let mut dataset = lance_datagen::gen_batch() + .col( + "str", + lance_datagen::array::cycle_utf8_literals(&["a", "b", "c", "d", "e", "f"]), + ) + .col( + "vec", + lance_datagen::array::rand_vec::(Dimension::from(4)), + ) + .into_ram_dataset_with_params( + FragmentCount::from(2), + FragmentRowCount::from(3), + Some(WriteParams { + max_rows_per_file: 3, + enable_stable_row_ids: true, + ..Default::default() + }), + ) + .await + .unwrap(); + + dataset + .create_index( + &["str"], + IndexType::Scalar, + Some("str_idx".to_string()), + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + + dataset + .create_index( + &["vec"], + IndexType::Vector, + Some("vec_idx".to_string()), + &VectorIndexParams::ivf_flat(1, MetricType::L2), + true, + ) + .await + .unwrap(); + + let initial_indices = dataset.load_indices().await.unwrap(); + let str_index = initial_indices + .iter() + .find(|idx| idx.name == "str_idx") + .unwrap(); + let vec_index = initial_indices + .iter() + .find(|idx| idx.name == "vec_idx") + .unwrap(); + + assert_eq!( + str_index + .fragment_bitmap + .as_ref() + .unwrap() + .iter() + .collect::>(), + vec![0, 1] + ); + assert_eq!( + vec_index + .fragment_bitmap + .as_ref() + .unwrap() + .iter() + .collect::>(), + vec![0, 1] + ); + + // insert data to create the third frag + let new_batch = lance_datagen::gen_batch() + .col( + "str", + lance_datagen::array::cycle_utf8_literals(&["g", "h", "i"]), + ) + .col( + "vec", + lance_datagen::array::rand_vec::(Dimension::from(4)), + ) + .into_batch_rows(RowCount::from(3)) + .unwrap(); + + dataset = InsertBuilder::new(WriteDestination::Dataset(Arc::new(dataset))) + .with_params(&WriteParams { + mode: WriteMode::Append, + enable_stable_row_ids: true, + ..Default::default() + }) + .execute(vec![new_batch]) + .await + .unwrap(); + + assert_eq!(dataset.get_fragments().len(), 3); + + let indices_after_insert = dataset.load_indices().await.unwrap(); + let str_index_after_insert = indices_after_insert + .iter() + .find(|idx| idx.name == "str_idx") + .unwrap(); + let vec_index_after_insert = indices_after_insert + .iter() + .find(|idx| idx.name == "vec_idx") + .unwrap(); + + assert_eq!( + str_index_after_insert + .fragment_bitmap + .as_ref() + .unwrap() + .len(), + 2 + ); + assert!(!str_index_after_insert + .fragment_bitmap + .as_ref() + .unwrap() + .contains(2)); + assert_eq!( + vec_index_after_insert + .fragment_bitmap + .as_ref() + .unwrap() + .len(), + 2 + ); + assert!(!vec_index_after_insert + .fragment_bitmap + .as_ref() + .unwrap() + .contains(2)); + + let updated_dataset = UpdateBuilder::new(Arc::new(dataset)) + // 'a' in fragment 0,'g' in fragment 2, and frag 2 not in frag bitmap + .update_where("str = 'a' OR str = 'g'") + .unwrap() + .set("vec", "array[99.0, 99.0, 99.0, 99.0]") + .unwrap() + .build() + .unwrap() + .execute() + .await + .unwrap() + .new_dataset; + + // reload indices + let updated_indices = updated_dataset.load_indices().await.unwrap(); + let updated_str_index = updated_indices + .iter() + .find(|idx| idx.name == "str_idx") + .unwrap(); + let updated_vec_index = updated_indices + .iter() + .find(|idx| idx.name == "vec_idx") + .unwrap(); + + let str_bitmap = updated_str_index.fragment_bitmap.as_ref().unwrap(); + let vec_bitmap = updated_vec_index.fragment_bitmap.as_ref().unwrap(); + + assert!(updated_dataset.get_fragments().len() > 3); + assert_eq!(str_bitmap.len(), 2); + assert_eq!(vec_bitmap.len(), 2); + + // frag 3 not in the index's frag bitmap + for &fragment_id in str_bitmap.iter().collect::>().iter() { + assert!(fragment_id < 2, + "str index bitmap should not contain fragments with unindexed data, found fragment {}", + fragment_id); + } + + // frag 3 not in the index's frag bitmap + for &fragment_id in vec_bitmap.iter().collect::>().iter() { + assert!(fragment_id < 2, + "vec index bitmap should not contain fragments with unindexed data, found fragment {}", + fragment_id); + } + } } diff --git a/rust/lance/src/io/commit/conflict_resolver.rs b/rust/lance/src/io/commit/conflict_resolver.rs index 52deab0a85d..490d92fc1da 100644 --- a/rust/lance/src/io/commit/conflict_resolver.rs +++ b/rust/lance/src/io/commit/conflict_resolver.rs @@ -1589,6 +1589,8 @@ mod tests { new_fragments: vec![], fields_modified: vec![], mem_wal_to_merge: None, + fields_for_preserving_frag_bitmap: vec![], + update_mode: None, }; let transaction = Transaction::new_from_version(1, operation); let other_operations = [ @@ -1598,6 +1600,8 @@ mod tests { new_fragments: vec![], fields_modified: vec![], mem_wal_to_merge: None, + fields_for_preserving_frag_bitmap: vec![], + update_mode: None, }, Operation::Delete { deleted_fragment_ids: vec![3], @@ -1610,6 +1614,8 @@ mod tests { new_fragments: vec![], fields_modified: vec![], mem_wal_to_merge: None, + fields_for_preserving_frag_bitmap: vec![], + update_mode: None, }, ]; let other_transactions = other_operations.map(|op| Transaction::new_from_version(2, op)); @@ -1709,6 +1715,8 @@ mod tests { new_fragments: vec![sample_file.clone()], fields_modified: vec![], mem_wal_to_merge: None, + fields_for_preserving_frag_bitmap: vec![], + update_mode: None, }, Operation::Delete { updated_fragments: vec![apply_deletion(&[1], &mut fragment, &dataset).await], @@ -1721,6 +1729,8 @@ mod tests { new_fragments: vec![sample_file], fields_modified: vec![], mem_wal_to_merge: None, + fields_for_preserving_frag_bitmap: vec![], + update_mode: None, }, ]; let transactions = @@ -1840,6 +1850,8 @@ mod tests { new_fragments: vec![sample_file.clone()], fields_modified: vec![], mem_wal_to_merge: None, + fields_for_preserving_frag_bitmap: vec![], + update_mode: None, }, ), ( @@ -1850,6 +1862,8 @@ mod tests { new_fragments: vec![sample_file.clone()], fields_modified: vec![], mem_wal_to_merge: None, + fields_for_preserving_frag_bitmap: vec![], + update_mode: None, }, ), ( @@ -2005,6 +2019,8 @@ mod tests { new_fragments: vec![fragment2.clone()], fields_modified: vec![0], mem_wal_to_merge: None, + fields_for_preserving_frag_bitmap: vec![], + update_mode: None, }, Operation::UpdateConfig { upsert_values: Some(HashMap::from_iter(vec![( @@ -2197,6 +2213,8 @@ mod tests { new_fragments: vec![fragment2], fields_modified: vec![0], mem_wal_to_merge: None, + fields_for_preserving_frag_bitmap: vec![], + update_mode: None, }, [ Compatible, // append