diff --git a/rust/lance/src/dataset/delta.rs b/rust/lance/src/dataset/delta.rs index d14e6d55869..1ee94e2d4e3 100644 --- a/rust/lance/src/dataset/delta.rs +++ b/rust/lance/src/dataset/delta.rs @@ -319,16 +319,20 @@ impl DatasetDelta { ])?; // Filter for rows created in the version range - let (begin_version, end_version) = self.resolve_range().await?; - let filter = format!( - "_row_created_at_version > {} AND _row_created_at_version <= {}", - begin_version, end_version - ); + let filter = self.build_inserted_rows_filter().await?; scanner.filter(&filter)?; scanner.try_into_stream().await } + async fn build_inserted_rows_filter(&self) -> Result { + let (begin_version, end_version) = self.resolve_range().await?; + Ok(format!( + "_row_created_at_version > {} AND _row_created_at_version <= {}", + begin_version, end_version + )) + } + /// Get updated rows between the two versions. /// /// This returns rows where `_row_last_updated_at_version` is greater than `begin_version` @@ -373,15 +377,83 @@ impl DatasetDelta { ])?; // Filter for rows that were updated (not inserted) in the version range + let filter = self.build_updated_rows_batch_filter().await?; + scanner.filter(&filter)?; + + scanner.try_into_stream().await + } + + async fn build_updated_rows_batch_filter(&self) -> Result { let (begin_version, end_version) = self.resolve_range().await?; - let filter = format!( + Ok(format!( "_row_created_at_version <= {} AND _row_last_updated_at_version > {} AND _row_last_updated_at_version <= {}", begin_version, begin_version, end_version - ); + )) + } + + /// Get upserted rows between the two versions. + /// + /// This returns rows that meet either of the following conditions: + /// Condition 1: + /// `_row_last_updated_at_version` is greater than `begin_version` + /// and less than or equal to `end_version`, but `_row_created_at_version` is less than + /// or equal to `begin_version` (rows created at or before `begin_version` that were updated afterwards).
+ /// Condition 2: + /// `_row_created_at_version` is greater than `begin_version` + /// and less than or equal to `end_version` (rows inserted in the range). + /// + /// The result always includes: + /// - `_row_created_at_version`: Version when the row was created + /// - `_row_last_updated_at_version`: Version when the row was last updated + /// - `_rowid`: Row ID + /// - All other columns from the dataset + /// + /// # Returns + /// + /// A stream of record batches containing the updated and inserted rows. + /// + /// # Example + /// + /// ``` + /// # use lance::{Dataset, Result}; + /// # use futures::TryStreamExt; + /// # async fn example(dataset: &Dataset, previous_version: u64) -> Result<()> { + /// let delta = dataset.delta() + /// .compared_against_version(previous_version) + /// .build()?; + /// let mut updated = delta.get_upserted_rows().await?; + /// while let Some(batch) = updated.try_next().await? { + /// // Process batch... + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn get_upserted_rows(&self) -> Result { + let mut scanner = self.base_dataset.scan(); + + // Enable version columns + scanner.project(&[ + WILDCARD, + ROW_ID, + ROW_CREATED_AT_VERSION, + ROW_LAST_UPDATED_AT_VERSION, + ])?; + + // Filter for rows that were updated or inserted in the version range + let filter = self.build_upserted_rows_filter().await?; scanner.filter(&filter)?; scanner.try_into_stream().await } + + async fn build_upserted_rows_filter(&self) -> Result { + let inserted_row_filter = self.build_inserted_rows_filter().await?; + let updated_rows_filter = self.build_updated_rows_batch_filter().await?; + Ok(format!( + "({}) OR ({})", + inserted_row_filter, updated_rows_filter + )) + } } #[cfg(test)] @@ -1405,6 +1477,62 @@ mod tests { } } + #[tokio::test] + async fn test_get_upsert_rows() { + // Create initial dataset (version 1) + let temp_dir = lance_core::utils::tempfile::TempStrDir::default(); + let ds = write_dataset_temp(&temp_dir, 0, 50, 1, "value", true,
+ false).await; + + assert_eq!(ds.version().version, 1); + + // Append inserted rows (version 2) + let ds = write_dataset_temp(&temp_dir, 50, 20, 1, "appended_v2", true, true).await; + assert_eq!(ds.version().version, 2); + + // Update some existing rows (version 3) + let ds = update_where(ds, "key < 10", "updated_v3").await; + assert_eq!(ds.version().version, 3); + + // Get upserted rows between version 1 and 3 + let delta = ds + .delta() + .with_begin_version(1) + .with_end_version(3) + .build() + .unwrap(); + + let stream = delta.get_upserted_rows().await.unwrap(); + let result = collect_stream(stream).await; + + // Should include 20 inserted rows (keys 50-69) and 10 updated rows (keys 0-9) + assert_eq!(result.num_rows(), 30); + assert!(result.column_by_name(ROW_ID).is_some()); + assert!(result.column_by_name(ROW_CREATED_AT_VERSION).is_some()); + assert!(result.column_by_name(ROW_LAST_UPDATED_AT_VERSION).is_some()); + + let created_at = result[ROW_CREATED_AT_VERSION] + .as_primitive::() + .values(); + let updated_at = result[ROW_LAST_UPDATED_AT_VERSION] + .as_primitive::() + .values(); + let keys = result["key"].as_primitive::().values(); + + for i in 0..result.num_rows() { + let key = keys[i]; + if key < 10 { + // Updated rows from version 3 + assert_eq!(created_at[i], 1); + assert_eq!(updated_at[i], 3); + } else { + // Inserted rows from version 2 + assert!((50..70).contains(&key)); + assert_eq!(created_at[i], 2); + assert_eq!(updated_at[i], 2); + } + } + } + #[tokio::test] async fn test_build_with_date_window_basic() { MockClock::set_system_time(std::time::Duration::from_secs(10));