Skip to content
142 changes: 135 additions & 7 deletions rust/lance/src/dataset/delta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,16 +319,20 @@ impl DatasetDelta {
])?;

// Filter for rows created in the version range
let (begin_version, end_version) = self.resolve_range().await?;
let filter = format!(
"_row_created_at_version > {} AND _row_created_at_version <= {}",
begin_version, end_version
);
let filter = self.build_inserted_rows_filter().await?;
scanner.filter(&filter)?;

scanner.try_into_stream().await
}

async fn build_inserted_rows_filter(&self) -> Result<String> {
let (begin_version, end_version) = self.resolve_range().await?;
Ok(format!(
"_row_created_at_version > {} AND _row_created_at_version <= {}",
begin_version, end_version
))
}

/// Get updated rows between the two versions.
///
/// This returns rows where `_row_last_updated_at_version` is greater than `begin_version`
Expand Down Expand Up @@ -373,15 +377,83 @@ impl DatasetDelta {
])?;

// Filter for rows that were updated (not inserted) in the version range
let filter = self.build_updated_rows_batch_filter().await?;
scanner.filter(&filter)?;

scanner.try_into_stream().await
}

async fn build_updated_rows_batch_filter(&self) -> Result<String> {
let (begin_version, end_version) = self.resolve_range().await?;
let filter = format!(
Ok(format!(
"_row_created_at_version <= {} AND _row_last_updated_at_version > {} AND _row_last_updated_at_version <= {}",
begin_version, begin_version, end_version
);
))
}

/// Get upserted rows between the two versions.
///
/// This returns rows meet following conditions:
/// Condition 1:
/// `_row_last_updated_at_version` is greater than `begin_version`
/// and less than or equal to `end_version`, but `_row_created_at_version` is less than
/// or equal to `begin_version` (to exclude newly inserted rows).
/// Condition 2:
/// This returns rows where `_row_created_at_version` is greater than `begin_version`
/// and less than or equal to `end_version`.
///
/// The result always includes:
/// - `_row_created_at_version`: Version when the row was created
/// - `_row_last_updated_at_version`: Version when the row was last updated
/// - `_rowid`: Row ID
/// - All other columns from the dataset
///
/// # Returns
///
/// A stream of record batches containing the updated and inserted rows.
///
/// # Example
///
/// ```
/// # use lance::{Dataset, Result};
/// # use futures::TryStreamExt;
/// # async fn example(dataset: &Dataset, previous_version: u64) -> Result<()> {
/// let delta = dataset.delta()
/// .compared_against_version(previous_version)
/// .build()?;
/// let mut updated = delta.get_upserted_rows().await?;
/// while let Some(batch) = updated.try_next().await? {
/// // Process batch...
/// }
/// # Ok(())
/// # }
/// ```
pub async fn get_upserted_rows(&self) -> Result<DatasetRecordBatchStream> {
let mut scanner = self.base_dataset.scan();

// Enable version columns
scanner.project(&[
WILDCARD,
ROW_ID,
ROW_CREATED_AT_VERSION,
ROW_LAST_UPDATED_AT_VERSION,
])?;

// Filter for rows that were updated or inserted in the version range
let filter = self.build_upserted_rows_filter().await?;
scanner.filter(&filter)?;

scanner.try_into_stream().await
}

async fn build_upserted_rows_filter(&self) -> Result<String> {
let inserted_row_filter = self.build_inserted_rows_filter().await?;
let updated_rows_filter = self.build_updated_rows_batch_filter().await?;
Ok(format!(
"({}) OR ({})",
inserted_row_filter, updated_rows_filter
))
}
}

#[cfg(test)]
Expand Down Expand Up @@ -1405,6 +1477,62 @@ mod tests {
}
}

#[tokio::test]
async fn test_get_upsert_rows() {
// Create initial dataset (version 1)
let temp_dir = lance_core::utils::tempfile::TempStrDir::default();
let ds = write_dataset_temp(&temp_dir, 0, 50, 1, "value", true, false).await;

assert_eq!(ds.version().version, 1);

// Append inserted rows (version 2)
let ds = write_dataset_temp(&temp_dir, 50, 20, 1, "appended_v2", true, true).await;
assert_eq!(ds.version().version, 2);

// Update some existing rows (version 3)
let ds = update_where(ds, "key < 10", "updated_v3").await;
assert_eq!(ds.version().version, 3);

// Get upserted rows between version 1 and 3
let delta = ds
.delta()
.with_begin_version(1)
.with_end_version(3)
.build()
.unwrap();

let stream = delta.get_upserted_rows().await.unwrap();
let result = collect_stream(stream).await;

// Should include 20 inserted rows (keys 50-69) and 10 updated rows (keys 0-9)
assert_eq!(result.num_rows(), 30);
assert!(result.column_by_name(ROW_ID).is_some());
assert!(result.column_by_name(ROW_CREATED_AT_VERSION).is_some());
assert!(result.column_by_name(ROW_LAST_UPDATED_AT_VERSION).is_some());

let created_at = result[ROW_CREATED_AT_VERSION]
.as_primitive::<UInt64Type>()
.values();
let updated_at = result[ROW_LAST_UPDATED_AT_VERSION]
.as_primitive::<UInt64Type>()
.values();
let keys = result["key"].as_primitive::<Int32Type>().values();

for i in 0..result.num_rows() {
let key = keys[i];
if key < 10 {
// Updated rows from version 3
assert_eq!(created_at[i], 1);
assert_eq!(updated_at[i], 3);
} else {
// Inserted rows from version 2
assert!((50..70).contains(&key));
assert_eq!(created_at[i], 2);
assert_eq!(updated_at[i], 2);
}
}
}

#[tokio::test]
async fn test_build_with_date_window_basic() {
MockClock::set_system_time(std::time::Duration::from_secs(10));
Expand Down