From ba5126ddc65bed3ca17ba42a112636a179bb8521 Mon Sep 17 00:00:00 2001 From: "fangbo.0511" Date: Tue, 9 Sep 2025 20:52:00 +0800 Subject: [PATCH 1/9] feat(java): expose merge_insert api --- java/lance-jni/Cargo.toml | 2 + java/lance-jni/src/lib.rs | 1 + java/lance-jni/src/merge_insert.rs | 255 ++++++++++++++ .../main/java/com/lancedb/lance/Dataset.java | 33 ++ .../com/lancedb/lance/merge/MergeInsert.java | 311 ++++++++++++++++++ .../lance/merge/MergeInsertResult.java | 34 ++ .../lancedb/lance/merge/MergeInsertStats.java | 80 +++++ .../lance/operation/MergeInsertTest.java | 249 ++++++++++++++ rust/lance-datafusion/src/planner.rs | 2 +- 9 files changed, 966 insertions(+), 1 deletion(-) create mode 100644 java/lance-jni/src/merge_insert.rs create mode 100644 java/src/main/java/com/lancedb/lance/merge/MergeInsert.java create mode 100644 java/src/main/java/com/lancedb/lance/merge/MergeInsertResult.java create mode 100644 java/src/main/java/com/lancedb/lance/merge/MergeInsertStats.java create mode 100644 java/src/test/java/com/lancedb/lance/operation/MergeInsertTest.java diff --git a/java/lance-jni/Cargo.toml b/java/lance-jni/Cargo.toml index 16786c3f828..4f3a9a2d4dd 100644 --- a/java/lance-jni/Cargo.toml +++ b/java/lance-jni/Cargo.toml @@ -39,3 +39,5 @@ prost = "0.13.5" roaring = "0.10.1" prost-types = "0.13.5" chrono = "0.4.41" +datafusion-common = "48.0.0" +datafusion-sql = "48.0.0" diff --git a/java/lance-jni/src/lib.rs b/java/lance-jni/src/lib.rs index 689b2e3e41f..5f913bae896 100644 --- a/java/lance-jni/src/lib.rs +++ b/java/lance-jni/src/lib.rs @@ -46,6 +46,7 @@ pub mod ffi; mod file_reader; mod file_writer; mod fragment; +mod merge_insert; mod schema; mod sql; pub mod traits; diff --git a/java/lance-jni/src/merge_insert.rs b/java/lance-jni/src/merge_insert.rs new file mode 100644 index 00000000000..b119e7ef51a --- /dev/null +++ b/java/lance-jni/src/merge_insert.rs @@ -0,0 +1,255 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use crate::blocking_dataset::{BlockingDataset, NATIVE_DATASET}; +use crate::error::Result; +use crate::traits::{FromJString, IntoJava}; +use crate::{Error, JNIEnvExt, RT}; +use arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream}; +use arrow_schema::Schema; +use datafusion_common::DFSchema; +use datafusion_sql::parser::DFParserBuilder; +use datafusion_sql::planner::{PlannerContext, SqlToRel}; +use jni::objects::{JObject, JString, JValueGen}; +use jni::sys::jlong; +use jni::JNIEnv; +use lance::dataset::{ + MergeInsertBuilder, MergeStats, WhenMatched, WhenNotMatched, WhenNotMatchedBySource, +}; +use lance_datafusion::planner::LanceContextProvider; +use std::sync::Arc; +use std::time::Duration; + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeMergeInsert<'a>( + mut env: JNIEnv<'a>, + jdataset: JObject, + jparam: JObject, + batch_address: jlong, +) -> JObject<'a> { + ok_or_throw!( + env, + inner_merge_insert(&mut env, jdataset, jparam, batch_address) + ) +} + +#[allow(clippy::too_many_arguments)] +fn inner_merge_insert<'local>( + env: &mut JNIEnv<'local>, + jdataset: JObject, + jparam: JObject, + batch_address: jlong, +) -> Result> { + let on = extract_on(env, &jparam)?; + let when_matched = extract_when_matched(env, &jparam)?; + let when_not_matched = extract_when_not_matached(env, &jparam)?; + + let when_not_matched_by_source_str = extract_when_not_matched_by_source_str(env, &jparam)?; + let when_not_matched_by_source_delete_expr = + extract_when_not_matched_by_source_delete_expr(env, &jparam)?; + + let conflict_retries = extract_conflict_retries(env, &jparam)?; + let retry_timeout_ms = extract_retry_timeout_ms(env, &jparam)?; + let skip_auto_cleanup = extract_skip_auto_cleanup(env, &jparam)?; + + let (new_ds, merge_stats) = unsafe { + let dataset = env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET)?; + + let when_not_matched_by_source = extract_when_not_matched_by_source( + Schema::from(dataset.inner.schema()), + when_not_matched_by_source_str.as_str(), + when_not_matched_by_source_delete_expr.as_str(), + )?; + + let merge_insert_job = MergeInsertBuilder::try_new(Arc::new(dataset.clone().inner), on)? + .when_matched(when_matched) + .when_not_matched(when_not_matched) + .when_not_matched_by_source(when_not_matched_by_source) + .conflict_retries(conflict_retries) + .retry_timeout(Duration::from_millis(retry_timeout_ms as u64)) + .skip_auto_cleanup(skip_auto_cleanup) + .try_build()?; + + let stream_ptr = batch_address as *mut FFI_ArrowArrayStream; + let source_stream = ArrowArrayStreamReader::from_raw(stream_ptr)?; + + RT.block_on(async move { merge_insert_job.execute_reader(source_stream).await })? + }; + + Ok(MergeResult( + BlockingDataset { + inner: Arc::try_unwrap(new_ds).unwrap(), + }, + merge_stats, + ) + .into_java(env)?) +} + +fn extract_on<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result> { + let on: JObject = env + .call_method(jparam, "on", "()Ljava/util/List;", &[])? + .l()?; + env.get_strings(&on) +} + +fn extract_when_matched<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result { + let when_matched: JString = env + .call_method(jparam, "whenMatchedValue", "()Ljava/lang/String;", &[])? + .l()? + .into(); + let when_matched = when_matched.extract(env)?; + + let when_matched_update_expr: JString = env + .call_method(jparam, "whenMatchedUpdateExpr", "()Ljava/lang/String;", &[])? + .l()? + .into(); + let when_matched_update_expr = when_matched_update_expr.extract(env)?; + + match when_matched.as_str() { + "UpdateAll" => Ok(WhenMatched::UpdateAll), + "DoNothing" => Ok(WhenMatched::DoNothing), + "UpdateIf" => Ok(WhenMatched::UpdateIf(when_matched_update_expr)), + _ => Err(Error::input_error(format!( + "Illegal when_matched: {when_matched}", + ))), + } +} + +fn extract_when_not_matached<'local>( + env: &mut JNIEnv<'local>, + jparam: &JObject, +) -> Result { + let when_not_matched: JString = env + .call_method(jparam, "whenNotMatchedValue", "()Ljava/lang/String;", &[])? + .l()? + .into(); + let when_not_matched = when_not_matched.extract(env)?; + + match when_not_matched.as_str() { + "InsertAll" => Ok(WhenNotMatched::InsertAll), + "DoNothing" => Ok(WhenNotMatched::DoNothing), + _ => Err(Error::input_error(format!( + "Illegal when_not_matched: {when_not_matched}", + ))), + } +} + +fn extract_when_not_matched_by_source_str<'local>( + env: &mut JNIEnv<'local>, + jparam: &JObject, +) -> Result { + let when_not_matched_by_source: JString = env + .call_method( + jparam, + "whenNotMatchedBySourceValue", + "()Ljava/lang/String;", + &[], + )? + .l()? + .into(); + when_not_matched_by_source.extract(env) +} + +fn extract_when_not_matched_by_source_delete_expr<'local>( + env: &mut JNIEnv<'local>, + jparam: &JObject, +) -> Result { + let when_not_matched_by_source_delete_expr: JString = env + .call_method( + jparam, + "whenNotMatchedBySourceDeleteExpr", + "()Ljava/lang/String;", + &[], + )? + .l()? + .into(); + when_not_matched_by_source_delete_expr.extract(env) +} + +fn extract_when_not_matched_by_source<'local>( + schema: Schema, + when_not_matched_by_source: &str, + when_not_matched_by_source_delete_expr: &str, +) -> Result { + match when_not_matched_by_source { + "Keep" => Ok(WhenNotMatchedBySource::Keep), + "Delete" => Ok(WhenNotMatchedBySource::Delete), + "DeleteIf" => { + let sql_expr = DFParserBuilder::new(when_not_matched_by_source_delete_expr) + .build() + .unwrap() + .parser + .parse_expr() + .unwrap(); + + let expr = SqlToRel::new(&LanceContextProvider::default()) + .sql_to_expr( + sql_expr, + &DFSchema::try_from(schema).unwrap(), + &mut PlannerContext::default(), + ) + .unwrap(); + + Ok(WhenNotMatchedBySource::DeleteIf(expr)) + } + _ => Err(Error::input_error(format!( + "Illegal when_not_matched_by_source: {when_not_matched_by_source}", + ))), + } +} + +fn extract_conflict_retries<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result { + let retries = env + .call_method(jparam, "conflictRetries", "()I", &[])? + .i()? as u32; + Ok(retries) +} + +fn extract_retry_timeout_ms<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result { + let timeout_ms = env.call_method(jparam, "retryTimeoutMs", "()J", &[])?.j()? as u64; + Ok(timeout_ms) +} + +fn extract_skip_auto_cleanup<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result { + let skip_auto_cleanup = env + .call_method(jparam, "skipAutoCleanup", "()Z", &[])? + .z()?; + Ok(skip_auto_cleanup) +} + +const MERGE_STATS_CLASS: &str = "com/lancedb/lance/merge/MergeInsertStats"; +const MERGE_STATS_CONSTRUCTOR_SIG: &str = "(JJJIJJ)V"; +const MERGE_RESULT_CLASS: &str = "com/lancedb/lance/merge/MergeInsertResult"; +const MERGE_RESULT_CONSTRUCTOR_SIG: &str = + "(Lcom/lancedb/lance/Dataset;Lcom/lancedb/lance/merge/MergeInsertStats;)V"; + +impl IntoJava for MergeStats { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + Ok(env.new_object( + MERGE_STATS_CLASS, + MERGE_STATS_CONSTRUCTOR_SIG, + &[ + JValueGen::Long(self.num_inserted_rows as i64), + JValueGen::Long(self.num_updated_rows as i64), + JValueGen::Long(self.num_deleted_rows as i64), + JValueGen::Int(self.num_attempts as i32), + JValueGen::Long(self.bytes_written as i64), + JValueGen::Long(self.num_files_written as i64), + ], + )?) + } +} + +struct MergeResult(BlockingDataset, MergeStats); + +impl IntoJava for MergeResult { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + let jdataset = self.0.into_java(env)?; + let jstats = self.1.into_java(env)?; + Ok(env.new_object( + MERGE_RESULT_CLASS, + MERGE_RESULT_CONSTRUCTOR_SIG, + &[JValueGen::Object(&jdataset), JValueGen::Object(&jstats)], + )?) + } +} diff --git a/java/src/main/java/com/lancedb/lance/Dataset.java b/java/src/main/java/com/lancedb/lance/Dataset.java index 423e1b234be..149bf3f0df6 100644 --- a/java/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/src/main/java/com/lancedb/lance/Dataset.java @@ -18,6 +18,8 @@ import com.lancedb.lance.ipc.DataStatistics; import com.lancedb.lance.ipc.LanceScanner; import com.lancedb.lance.ipc.ScanOptions; +import com.lancedb.lance.merge.MergeInsert; +import com.lancedb.lance.merge.MergeInsertResult; import com.lancedb.lance.schema.ColumnAlteration; import com.lancedb.lance.schema.LanceSchema; import com.lancedb.lance.schema.SqlExpressions; @@ -961,6 +963,37 @@ public SqlQuery sql(String sql) { return new SqlQuery(this, sql); } + /** + * Merge source data with the existing target data. + * + *

This will take in the source, merge it with the existing target data, and insert new rows, + * update existing rows, and delete existing rows. + * + *

It is important that after merge insert, the current dataset is changed and should be + * closed. The merged new dataset is contained in the MergeInsertResult. + * + * @param mergeInsert MergeInsert options + * @param source ArrowArrayStream source data + * @return MergeInsertResult containing the new merged Dataset. + */ + public MergeInsertResult mergeInsert(MergeInsert mergeInsert, ArrowArrayStream source) { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + MergeInsertResult result = nativeMergeInsert(mergeInsert, source.memoryAddress()); + + Dataset newDataset = result.dataset(); + if (selfManagedAllocator) { + newDataset.allocator = new RootAllocator(Long.MAX_VALUE); + } else { + newDataset.allocator = allocator; + } + + return new MergeInsertResult(newDataset, result.stats()); + } + } + + private native MergeInsertResult nativeMergeInsert( + MergeInsert mergeInsert, long arrowStreamMemoryAddress); + private native void nativeCreateTag(String tag, long version); private native void nativeDeleteTag(String tag); diff --git a/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java b/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java new file mode 100644 index 00000000000..1ecd3c2b5de --- /dev/null +++ b/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java @@ -0,0 +1,311 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.merge; + +import com.lancedb.lance.Dataset; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; + +import java.util.List; + +public class MergeInsert { + private final List on; + + private WhenMatched whenMatched = WhenMatched.DoNothing; + private String whenMatchedUpdateExpr = ""; + + private WhenNotMatched whenNotMatched = WhenNotMatched.InsertAll; + + private WhenNotMatchedBySource whenNotMatchedBySource = WhenNotMatchedBySource.Keep; + private String whenNotMatchedBySourceDeleteExpr = ""; + + private int conflictRetries = 10; + private long retryTimeoutMs = 30 * 1000; + private boolean skipAutoCleanup = false; + + public MergeInsert(List on) { + this.on = on; + } + + /** + * Specify that when a row in the source table matches a row in the target table, the row is + * deleted from the target table and the matched row based on the source table is inserted. + * + *

This can be used to achieve upsert behavior. + * + * @return This MergeInsert instance + */ + public MergeInsert withMatchedUpdateAll() { + this.whenMatched = WhenMatched.UpdateAll; + return this; + } + + /** + * Specify that when a row in the source table matches a row in the target table, the row in the + * target table is kept unchanged. + * + *

This can be used to achieve find-or-create behavior. + * + * @return This MergeInsert instance + */ + public MergeInsert withMatchedDoNothing() { + this.whenMatched = WhenMatched.DoNothing; + return this; + } + + /** + * Specify that when a row in the source table matches a row in the target table and the + * expression evaluates to true, the row in the target table is updated by the matched row from + * the source table. + * + *

This can be used to achieve upsert behavior. + * + *

The expression can reference source tables' columns with source. and target + * tables' columns with target. This is an example: + * source.column1 = target.column1 AND source.column2 = target.column2 + * + * @param expr The expression to evaluate on the rows in the source table and target table. + * @return This MergeInsert instance + */ + public MergeInsert withMatchedUpdateIf(String expr) { + this.whenMatched = WhenMatched.UpdateIf; + this.whenMatchedUpdateExpr = expr; + return this; + } + + /** + * Specify what should happen when a source row has no match in the target. + * + * @param whenNotMatched The action to take when a source row has no match in the target. + * @return This MergeInsert instance + */ + public MergeInsert withNotMatched(WhenNotMatched whenNotMatched) { + this.whenNotMatched = whenNotMatched; + return this; + } + + /** + * Specify that when a target row has no match in the source, the row is kept in the target table. + * + * @return This MergeInsert instance + */ + public MergeInsert withNotMatchedBySourceKeep() { + this.whenNotMatchedBySource = WhenNotMatchedBySource.Keep; + return this; + } + + /** + * Specify that when a target row has no match in the source, the row is deleted from the target + * table. + * + * @return This MergeInsert instance + */ + public MergeInsert withNotMatchedBySourceDelete() { + this.whenNotMatchedBySource = WhenNotMatchedBySource.Delete; + return this; + } + + /** + * Specify that when a target row has no match in the source and the expression evaluates to true, + * the row is deleted from the target table. + * + * @param expr The expression to evaluate on the rows in the target table. + * @return This MergeInsert instance + */ + public MergeInsert withNotMatchedBySourceDeleteIf(String expr) { + this.whenNotMatchedBySource = WhenNotMatchedBySource.DeleteIf; + this.whenNotMatchedBySourceDeleteExpr = expr; + return this; + } + + /** + * Set number of times to retry the operation if there is contention. + * + *

If this is set > 0, then the operation will keep a copy of the input data either in memory + * or on disk (depending on the size of the data) and will retry the operation if there is + * contention. + * + *

Default is 10. + * + * @param retries Number of times to retry the operation if there is contention. + * @return This MergeInsert instance + */ + public MergeInsert withConflictRetries(int retries) { + this.conflictRetries = retries; + return this; + } + + /** + * Set the timeout in milliseconds used to limit retries. + * + *

This is the maximum time to spend on the operation before giving up. At least one attempt + * will be made, regardless of how long it takes to complete. Subsequent attempts will be + * cancelled once this timeout is reached. If the timeout has been reached during the first + * attempt, the operation will be cancelled immediately. + * + *

Default is 30000. + * + * @param timeoutMs Timeout in milliseconds used to limit retries. + * @return This MergeInsert instance + */ + public MergeInsert withRetryTimeoutMs(long timeoutMs) { + this.retryTimeoutMs = timeoutMs; + return this; + } + + /** + * If true, skip auto cleanup during commits. This should be set to true for high frequency writes + * to improve performance. This is also useful if the writer does not have delete permissions and + * the clean up would just try and log a failure anyway. + * + * @param skipAutoCleanup Whether to skip auto cleanup during commits. + * @return This MergeInsert instance + */ + public MergeInsert withSkipAutoCleanup(boolean skipAutoCleanup) { + this.skipAutoCleanup = skipAutoCleanup; + return this; + } + + public List on() { + return on; + } + + public WhenMatched whenMatched() { + return whenMatched; + } + + public String whenMatchedValue() { + return whenMatched.name(); + } + + public String whenMatchedUpdateExpr() { + return whenMatchedUpdateExpr; + } + + public WhenNotMatched whenNotMatched() { + return whenNotMatched; + } + + public String whenNotMatchedValue() { + return whenNotMatched.name(); + } + + public WhenNotMatchedBySource whenNotMatchedBySource() { + return whenNotMatchedBySource; + } + + public String whenNotMatchedBySourceValue() { + return whenNotMatchedBySource.name(); + } + + public String whenNotMatchedBySourceDeleteExpr() { + return whenNotMatchedBySourceDeleteExpr; + } + + public int conflictRetries() { + return conflictRetries; + } + + public long retryTimeoutMs() { + return retryTimeoutMs; + } + + public boolean skipAutoCleanup() { + return skipAutoCleanup; + } + + public MergeInsertStats execute(Dataset dataset, VectorSchemaRoot source) { + BufferAllocator allocator = dataset.allocator(); + try (ArrowArray ffiArrowArray = ArrowArray.allocateNew(allocator); + ArrowSchema ffiArrowSchema = ArrowSchema.allocateNew(allocator)) { + Data.exportVectorSchemaRoot(allocator, source, null, ffiArrowArray, ffiArrowSchema); + return nativeExecute( + dataset, + on, + whenMatched.name(), + whenMatchedUpdateExpr, + whenNotMatched.name(), + whenNotMatchedBySource.name(), + whenNotMatchedBySourceDeleteExpr, + conflictRetries, + retryTimeoutMs, + skipAutoCleanup, + ffiArrowArray.memoryAddress(), + ffiArrowSchema.memoryAddress()); + } + } + + private static native MergeInsertStats nativeExecute( + Dataset dataset, + List on, + String whenMatched, + String whenMatchedUpdateExpr, + String whenNotMatched, + String whenNotMatchedBySource, + String whenNotMatchedDeleteExpr, + int conflictRetries, + long retryTimeoutMs, + boolean skipAutoCleanup, + long batchMemoryAddress, + long schemaMemoryAddress); + + public enum WhenMatched { + /** + * The row is deleted from the target table and a new row is inserted based on the source table. + * This can be used to achieve upsert behavior. + */ + UpdateAll, + + /** The row is kept unchanged. This can be used to achieve find-or-create behavior. */ + DoNothing, + + /** + * The row is updated (similar to UpdateAll) only for rows where the expression evaluates to + * true. + */ + UpdateIf, + } + + public enum WhenNotMatched { + /** + * The new row is inserted into the target table. This is used in both find-or-create and upsert + * operations + */ + InsertAll, + + /** The new row is ignored. */ + DoNothing, + } + + public enum WhenNotMatchedBySource { + /** + * Do not delete rows from the target table This can be used for a find-or-create or an upsert + * operation + */ + Keep, + + /** Delete all rows from target table that don't match a row in the source table */ + Delete, + + /** + * Delete rows from the target table if there is no match AND the expression evaluates to true + * This can be used to replace a region of data with new data + */ + DeleteIf, + } +} diff --git a/java/src/main/java/com/lancedb/lance/merge/MergeInsertResult.java b/java/src/main/java/com/lancedb/lance/merge/MergeInsertResult.java new file mode 100644 index 00000000000..8e539ac9728 --- /dev/null +++ b/java/src/main/java/com/lancedb/lance/merge/MergeInsertResult.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.merge; + +import com.lancedb.lance.Dataset; + +public class MergeInsertResult { + private final Dataset dataset; + private final MergeInsertStats stats; + + public MergeInsertResult(Dataset dataset, MergeInsertStats stats) { + this.dataset = dataset; + this.stats = stats; + } + + public Dataset dataset() { + return dataset; + } + + public MergeInsertStats stats() { + return stats; + } +} diff --git a/java/src/main/java/com/lancedb/lance/merge/MergeInsertStats.java b/java/src/main/java/com/lancedb/lance/merge/MergeInsertStats.java new file mode 100644 index 00000000000..58b8c15574a --- /dev/null +++ b/java/src/main/java/com/lancedb/lance/merge/MergeInsertStats.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.merge; + +public final class MergeInsertStats { + private final long numInsertedRows; + private final long numUpdatedRows; + private final long numDeletedRows; + private final int numAttempts; + private final long bytesWritten; + private final long numFilesWritten; + + public MergeInsertStats( + long numInsertedRows, + long numUpdatedRows, + long numDeletedRows, + int numAttempts, + long bytesWritten, + long numFilesWritten) { + this.numInsertedRows = numInsertedRows; + this.numUpdatedRows = numUpdatedRows; + this.numDeletedRows = numDeletedRows; + this.numAttempts = numAttempts; + this.bytesWritten = bytesWritten; + this.numFilesWritten = numFilesWritten; + } + + public long numInsertedRows() { + return numInsertedRows; + } + + public long numUpdatedRows() { + return numUpdatedRows; + } + + public long numDeletedRows() { + return numDeletedRows; + } + + public int numAttempts() { + return numAttempts; + } + + public long bytesWritten() { + return bytesWritten; + } + + public long numFilesWritten() { + return numFilesWritten; + } + + @Override + public String toString() { + return "MergeInsertStats{" + + "numInsertedRows=" + + numInsertedRows + + ", numUpdatedRows=" + + numUpdatedRows + + ", numDeletedRows=" + + numDeletedRows + + ", numAttempts=" + + numAttempts + + ", bytesWritten=" + + bytesWritten + + ", numFilesWritten=" + + numFilesWritten + + '}'; + } +} diff --git a/java/src/test/java/com/lancedb/lance/operation/MergeInsertTest.java b/java/src/test/java/com/lancedb/lance/operation/MergeInsertTest.java new file mode 100644 index 00000000000..ebc56dbddaa --- /dev/null +++ b/java/src/test/java/com/lancedb/lance/operation/MergeInsertTest.java @@ -0,0 +1,249 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.operation; + +import com.lancedb.lance.Dataset; +import com.lancedb.lance.TestUtils; +import com.lancedb.lance.merge.MergeInsert; +import com.lancedb.lance.merge.MergeInsertResult; + +import org.apache.arrow.c.ArrowArrayStream; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.TreeMap; + +public class MergeInsertTest extends OperationTestBase { + + @Test + public void testWhenNotMatchedInsertAll(@TempDir Path tempDir) throws Exception { + // Test insert all unmatched source rows + + String datasetPath = tempDir.resolve("testWhenNotMatchedInsertAll").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + + int rowCount = 5; + try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { + + VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator); + ArrowArrayStream sourceStream = convertToStream(source, allocator); + MergeInsertResult result = + initialDataset.mergeInsert(new MergeInsert(Arrays.asList("id")), sourceStream); + + Assertions.assertEquals( + "{0=Person 0, 1=Person 1, 2=Person 2, 3=Person 3, 4=Person 4, 7=Source 7, 8=Source 8, 9=Source 9}", + readAll(result.dataset()).toString()); + + sourceStream.close(); + source.close(); + } + } + } + + @Test + public void testWhenNotMatchedDoNothing(@TempDir Path tempDir) throws Exception { + // Test ignore unmatched source rows + + String datasetPath = tempDir.resolve("testWhenNotMatchedDoNothing").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + + int rowCount = 5; + try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { + + VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator); + ArrowArrayStream sourceStream = convertToStream(source, allocator); + MergeInsertResult result = + initialDataset.mergeInsert( + new MergeInsert(Arrays.asList("id")) + .withMatchedUpdateAll() + .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Source 0, 1=Source 1, 2=Source 2, 3=Person 3, 4=Person 4}", + readAll(result.dataset()).toString()); + + sourceStream.close(); + source.close(); + } + } + } + + @Test + public void testWhenMatchedUpdateIf(@TempDir Path tempDir) throws Exception { + // Test update matched rows if expression is true + + String datasetPath = tempDir.resolve("testWhenMatchedUpdateIf").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + + int rowCount = 5; + try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { + + VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator); + ArrowArrayStream sourceStream = convertToStream(source, allocator); + MergeInsertResult result = + initialDataset.mergeInsert( + new MergeInsert(Arrays.asList("id")) + .withMatchedUpdateIf("target.name = 'Person 0' or target.name = 'Person 1'") + .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Source 0, 1=Source 1, 2=Person 2, 3=Person 3, 4=Person 4}", + readAll(result.dataset()).toString()); + + sourceStream.close(); + source.close(); + } + } + } + + @Test + public void testWhenNotMatchedBySourceDelete(@TempDir Path tempDir) throws Exception { + // Test delete target rows which are not matched with source. + + String datasetPath = tempDir.resolve("testWhenNotMatchedBySourceDelete").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + + int rowCount = 5; + try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { + + VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator); + ArrowArrayStream sourceStream = convertToStream(source, allocator); + MergeInsertResult result = + initialDataset.mergeInsert( + new MergeInsert(Arrays.asList("id")) + .withNotMatchedBySourceDelete() + .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Person 0, 1=Person 1, 2=Person 2}", readAll(result.dataset()).toString()); + + sourceStream.close(); + source.close(); + } + } + } + + @Test + public void testWhenNotMatchedBySourceDeleteIf(@TempDir Path tempDir) throws Exception { + // Test delete target rows which are not matched with source if expression is true + + String datasetPath = tempDir.resolve("testWhenNotMatchedBySourceDeleteIf").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + + int rowCount = 5; + try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { + + VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator); + ArrowArrayStream sourceStream = convertToStream(source, allocator); + MergeInsertResult result = + initialDataset.mergeInsert( + new MergeInsert(Arrays.asList("id")) + .withNotMatchedBySourceDeleteIf("name = 'Person 3'") + .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Person 0, 1=Person 1, 2=Person 2, 4=Person 4}", + readAll(result.dataset()).toString()); + + sourceStream.close(); + source.close(); + } + } + } + + private VectorSchemaRoot buildSource(Schema schema, RootAllocator allocator) throws Exception { + List sourceIds = Arrays.asList(0, 1, 2, 7, 8, 9); + + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + root.allocateNew(); + + IntVector idVector = (IntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + + for (int i = 0; i < sourceIds.size(); i++) { + idVector.setSafe(i, sourceIds.get(i)); + String name = "Source " + sourceIds.get(i); + nameVector.setSafe(i, name.getBytes(StandardCharsets.UTF_8)); + } + + root.setRowCount(sourceIds.size()); + + return root; + } + + private ArrowArrayStream convertToStream(VectorSchemaRoot root, RootAllocator allocator) + throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out)) { + writer.start(); + writer.writeBatch(); + writer.end(); + } + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + ArrowStreamReader reader = new ArrowStreamReader(in, allocator); + + ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator); + Data.exportArrayStream(allocator, reader, stream); + + return stream; + } + + private TreeMap readAll(Dataset dataset) throws Exception { + try (ArrowReader reader = dataset.newScan().scanBatches()) { + TreeMap map = new TreeMap<>(); + + while (reader.loadNextBatch()) { + VectorSchemaRoot batch = reader.getVectorSchemaRoot(); + for (int i = 0; i < batch.getRowCount(); i++) { + IntVector idVector = (IntVector) batch.getVector("id"); + VarCharVector nameVector = (VarCharVector) batch.getVector("name"); + map.put(idVector.get(i), new String(nameVector.get(i))); + } + } + + return map; + } + } +} diff --git a/rust/lance-datafusion/src/planner.rs b/rust/lance-datafusion/src/planner.rs index ddbce5e5b85..43f88fe03b2 100644 --- a/rust/lance-datafusion/src/planner.rs +++ b/rust/lance-datafusion/src/planner.rs @@ -152,7 +152,7 @@ impl ScalarUDFImpl for CastListF16Udf { } // Adapter that instructs datafusion how lance expects expressions to be interpreted -struct LanceContextProvider { +pub struct LanceContextProvider { options: datafusion::config::ConfigOptions, state: SessionState, expr_planners: Vec>, From f20805698ee28204c702149219aeeb7517c51829 Mon Sep 17 00:00:00 2001 From: "fangbo.0511" Date: Wed, 10 Sep 2025 19:16:28 +0800 Subject: [PATCH 2/9] Update test cases --- java/lance-jni/Cargo.toml | 4 +- .../lance/operation/MergeInsertTest.java | 135 +++++++++--------- 2 files changed, 67 insertions(+), 72 deletions(-) diff --git a/java/lance-jni/Cargo.toml b/java/lance-jni/Cargo.toml index 4f3a9a2d4dd..45d473f2118 100644 --- a/java/lance-jni/Cargo.toml +++ b/java/lance-jni/Cargo.toml @@ -39,5 +39,5 @@ prost = "0.13.5" roaring = "0.10.1" prost-types = "0.13.5" chrono = "0.4.41" -datafusion-common = "48.0.0" -datafusion-sql = "48.0.0" +datafusion-common = "49.0.2" +datafusion-sql = "49.0.2" diff --git a/java/src/test/java/com/lancedb/lance/operation/MergeInsertTest.java b/java/src/test/java/com/lancedb/lance/operation/MergeInsertTest.java index ebc56dbddaa..12eed70c931 100644 --- a/java/src/test/java/com/lancedb/lance/operation/MergeInsertTest.java +++ b/java/src/test/java/com/lancedb/lance/operation/MergeInsertTest.java @@ -54,17 +54,16 @@ public void testWhenNotMatchedInsertAll(@TempDir Path tempDir) throws Exception int rowCount = 5; try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { - VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator); - ArrowArrayStream sourceStream = convertToStream(source, allocator); - MergeInsertResult result = - initialDataset.mergeInsert(new MergeInsert(Arrays.asList("id")), sourceStream); - - Assertions.assertEquals( - "{0=Person 0, 1=Person 1, 2=Person 2, 3=Person 3, 4=Person 4, 7=Source 7, 8=Source 8, 9=Source 9}", - readAll(result.dataset()).toString()); - - sourceStream.close(); - source.close(); + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + initialDataset.mergeInsert(new MergeInsert(Arrays.asList("id")), sourceStream); + + Assertions.assertEquals( + "{0=Person 0, 1=Person 1, 2=Person 2, 3=Person 3, 4=Person 4, 7=Source 7, 8=Source 8, 9=Source 9}", + readAll(result.dataset()).toString()); + } + } } } } @@ -81,21 +80,20 @@ public void testWhenNotMatchedDoNothing(@TempDir Path tempDir) throws Exception int rowCount = 5; try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { - VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator); - ArrowArrayStream sourceStream = convertToStream(source, allocator); - MergeInsertResult result = - initialDataset.mergeInsert( - new MergeInsert(Arrays.asList("id")) - .withMatchedUpdateAll() - .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), - sourceStream); - - Assertions.assertEquals( - "{0=Source 0, 1=Source 1, 2=Source 2, 3=Person 3, 4=Person 4}", - readAll(result.dataset()).toString()); - - sourceStream.close(); - source.close(); + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + initialDataset.mergeInsert( + new MergeInsert(Arrays.asList("id")) + .withMatchedUpdateAll() + .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Source 0, 1=Source 1, 2=Source 2, 3=Person 3, 4=Person 4}", + readAll(result.dataset()).toString()); + } + } } } } @@ -112,21 +110,20 @@ public void testWhenMatchedUpdateIf(@TempDir Path tempDir) throws Exception { int rowCount = 5; try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { - VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator); - ArrowArrayStream sourceStream = convertToStream(source, allocator); - MergeInsertResult result = - initialDataset.mergeInsert( - new MergeInsert(Arrays.asList("id")) - .withMatchedUpdateIf("target.name = 'Person 0' or target.name = 'Person 1'") - .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), - sourceStream); - - Assertions.assertEquals( - "{0=Source 0, 1=Source 1, 2=Person 2, 3=Person 3, 4=Person 4}", - readAll(result.dataset()).toString()); - - sourceStream.close(); - source.close(); + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + initialDataset.mergeInsert( + new MergeInsert(Arrays.asList("id")) + .withMatchedUpdateIf("target.name = 'Person 0' or target.name = 'Person 1'") + .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Source 0, 1=Source 1, 2=Person 2, 3=Person 3, 4=Person 4}", + readAll(result.dataset()).toString()); + } + } } } } @@ -143,20 +140,19 @@ public void testWhenNotMatchedBySourceDelete(@TempDir Path tempDir) throws Excep int rowCount = 5; try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { - VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator); - ArrowArrayStream sourceStream = convertToStream(source, allocator); - MergeInsertResult result = - initialDataset.mergeInsert( - new MergeInsert(Arrays.asList("id")) - .withNotMatchedBySourceDelete() - .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), - sourceStream); - - Assertions.assertEquals( - "{0=Person 0, 1=Person 1, 2=Person 2}", readAll(result.dataset()).toString()); - - sourceStream.close(); - source.close(); + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + initialDataset.mergeInsert( + new MergeInsert(Arrays.asList("id")) + .withNotMatchedBySourceDelete() + .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Person 0, 1=Person 1, 2=Person 2}", readAll(result.dataset()).toString()); + } + } } } } @@ -173,21 +169,20 @@ public void testWhenNotMatchedBySourceDeleteIf(@TempDir Path tempDir) throws Exc int rowCount = 5; try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { - VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator); - ArrowArrayStream sourceStream = convertToStream(source, allocator); - MergeInsertResult result = - initialDataset.mergeInsert( - new MergeInsert(Arrays.asList("id")) - .withNotMatchedBySourceDeleteIf("name = 'Person 3'") - .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), - sourceStream); - - Assertions.assertEquals( - "{0=Person 0, 1=Person 1, 2=Person 2, 4=Person 4}", - readAll(result.dataset()).toString()); - - sourceStream.close(); - source.close(); + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + initialDataset.mergeInsert( + new MergeInsert(Arrays.asList("id")) + .withNotMatchedBySourceDeleteIf("name = 'Person 3'") + .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Person 0, 1=Person 1, 2=Person 2, 4=Person 4}", + readAll(result.dataset()).toString()); + } + } } } } From bbb1f1d4118c089ff414d06062e01c2deac9eda6 Mon Sep 17 00:00:00 2001 From: "fangbo.0511" Date: Wed, 10 Sep 2025 19:32:11 +0800 Subject: [PATCH 3/9] fix clippy issue --- java/lance-jni/src/merge_insert.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/java/lance-jni/src/merge_insert.rs b/java/lance-jni/src/merge_insert.rs index b119e7ef51a..6ae5cf45149 100644 --- a/java/lance-jni/src/merge_insert.rs +++ b/java/lance-jni/src/merge_insert.rs @@ -76,13 +76,13 @@ fn inner_merge_insert<'local>( RT.block_on(async move { merge_insert_job.execute_reader(source_stream).await })? }; - Ok(MergeResult( + MergeResult( BlockingDataset { inner: Arc::try_unwrap(new_ds).unwrap(), }, merge_stats, ) - .into_java(env)?) + .into_java(env) } fn extract_on<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result> { @@ -166,7 +166,7 @@ fn extract_when_not_matched_by_source_delete_expr<'local>( when_not_matched_by_source_delete_expr.extract(env) } -fn extract_when_not_matched_by_source<'local>( +fn extract_when_not_matched_by_source( schema: Schema, when_not_matched_by_source: &str, when_not_matched_by_source_delete_expr: &str, From 4f70c2fa6e7436bd01fc18cc7776c8cc9dada5aa Mon Sep 17 00:00:00 2001 From: "fangbo.0511" Date: Wed, 10 Sep 2025 19:49:31 +0800 Subject: [PATCH 4/9] fix comments --- java/src/main/java/com/lancedb/lance/merge/MergeInsert.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java b/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java index 1ecd3c2b5de..cfaef4c6206 100644 --- a/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java +++ b/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java @@ -136,9 +136,9 @@ public MergeInsert withNotMatchedBySourceDeleteIf(String expr) { /** * Set number of times to retry the operation if there is contention. * - *

If this is set > 0, then the operation will keep a copy of the input data either in memory - * or on disk (depending on the size of the data) and will retry the operation if there is - * contention. + *

If this is set greater than 0, then the operation will keep a copy of the input data either + * in memory or on disk (depending on the size of the data) and will retry the operation if there + * is contention. * *

Default is 10. * From ed4792e518c73f9b3392e13dc329aae6ad172eaa Mon Sep 17 00:00:00 2001 From: "fangbo.0511" Date: Wed, 10 Sep 2025 19:50:43 +0800 Subject: [PATCH 5/9] Remove unused code --- .../com/lancedb/lance/merge/MergeInsert.java | 43 ------------------- 1 file changed, 43 deletions(-) diff --git a/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java b/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java index cfaef4c6206..065ebe4ee3b 100644 --- a/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java +++ b/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java @@ -13,14 +13,6 @@ */ package com.lancedb.lance.merge; -import com.lancedb.lance.Dataset; - -import org.apache.arrow.c.ArrowArray; -import org.apache.arrow.c.ArrowSchema; -import org.apache.arrow.c.Data; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; - import java.util.List; public class MergeInsert { @@ -229,41 +221,6 @@ public boolean skipAutoCleanup() { return skipAutoCleanup; } - public MergeInsertStats execute(Dataset dataset, VectorSchemaRoot source) { - BufferAllocator allocator = dataset.allocator(); - try (ArrowArray ffiArrowArray = ArrowArray.allocateNew(allocator); - ArrowSchema ffiArrowSchema = ArrowSchema.allocateNew(allocator)) { - Data.exportVectorSchemaRoot(allocator, source, null, ffiArrowArray, ffiArrowSchema); - return nativeExecute( - dataset, - on, - whenMatched.name(), - whenMatchedUpdateExpr, - whenNotMatched.name(), - whenNotMatchedBySource.name(), - whenNotMatchedBySourceDeleteExpr, - conflictRetries, - retryTimeoutMs, - skipAutoCleanup, - ffiArrowArray.memoryAddress(), - ffiArrowSchema.memoryAddress()); - } - } - - private static native MergeInsertStats nativeExecute( - Dataset dataset, - List on, - String whenMatched, - String whenMatchedUpdateExpr, - String whenNotMatched, - String whenNotMatchedBySource, - String whenNotMatchedDeleteExpr, - int conflictRetries, - long retryTimeoutMs, - boolean skipAutoCleanup, - long batchMemoryAddress, - long schemaMemoryAddress); - public enum WhenMatched { /** * The row is deleted from the target table and a new row is inserted based on the source table. From 018aeb83eb112fd0a699e572526b0a8a0069c2c2 Mon Sep 17 00:00:00 2001 From: "fangbo.0511" Date: Thu, 11 Sep 2025 14:16:58 +0800 Subject: [PATCH 6/9] Fix review issues --- java/lance-jni/Cargo.toml | 2 - java/lance-jni/src/merge_insert.rs | 94 ++++--- .../main/java/com/lancedb/lance/Dataset.java | 10 +- ...ergeInsert.java => MergeInsertParams.java} | 106 +++++--- .../lancedb/lance/merge/MergeInsertStats.java | 24 +- .../com/lancedb/lance/MergeInsertTest.java | 217 ++++++++++++++++ .../lance/operation/MergeInsertTest.java | 244 ------------------ 7 files changed, 361 insertions(+), 336 deletions(-) rename java/src/main/java/com/lancedb/lance/merge/{MergeInsert.java => MergeInsertParams.java} (65%) create mode 100644 java/src/test/java/com/lancedb/lance/MergeInsertTest.java delete mode 100644 java/src/test/java/com/lancedb/lance/operation/MergeInsertTest.java diff --git a/java/lance-jni/Cargo.toml b/java/lance-jni/Cargo.toml index 45d473f2118..16786c3f828 100644 --- a/java/lance-jni/Cargo.toml +++ b/java/lance-jni/Cargo.toml @@ -39,5 +39,3 @@ prost = "0.13.5" roaring = "0.10.1" prost-types = "0.13.5" chrono = "0.4.41" -datafusion-common = "49.0.2" -datafusion-sql = "49.0.2" diff --git a/java/lance-jni/src/merge_insert.rs b/java/lance-jni/src/merge_insert.rs index 6ae5cf45149..b1ff490444f 100644 --- a/java/lance-jni/src/merge_insert.rs +++ b/java/lance-jni/src/merge_insert.rs @@ -6,26 +6,23 @@ use crate::error::Result; use crate::traits::{FromJString, IntoJava}; use crate::{Error, JNIEnvExt, RT}; use arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream}; -use arrow_schema::Schema; -use datafusion_common::DFSchema; -use datafusion_sql::parser::DFParserBuilder; -use datafusion_sql::planner::{PlannerContext, SqlToRel}; use jni::objects::{JObject, JString, JValueGen}; use jni::sys::jlong; use jni::JNIEnv; +use lance::dataset::scanner::LanceFilter; use lance::dataset::{ MergeInsertBuilder, MergeStats, WhenMatched, WhenNotMatched, WhenNotMatchedBySource, }; -use lance_datafusion::planner::LanceContextProvider; +use lance_core::datatypes::Schema; use std::sync::Arc; use std::time::Duration; #[no_mangle] pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeMergeInsert<'a>( mut env: JNIEnv<'a>, - jdataset: JObject, - jparam: JObject, - batch_address: jlong, + jdataset: JObject, // Dataset object + jparam: JObject, // MergeInsertParams object + batch_address: jlong, // ArrowArrayStream address for source ) -> JObject<'a> { ok_or_throw!( env, @@ -56,9 +53,9 @@ fn inner_merge_insert<'local>( let dataset = env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET)?; let when_not_matched_by_source = extract_when_not_matched_by_source( - Schema::from(dataset.inner.schema()), + dataset.inner.schema(), when_not_matched_by_source_str.as_str(), - when_not_matched_by_source_delete_expr.as_str(), + when_not_matched_by_source_delete_expr, )?; let merge_insert_job = MergeInsertBuilder::try_new(Arc::new(dataset.clone().inner), on)? @@ -99,16 +96,23 @@ fn extract_when_matched<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> R .into(); let when_matched = when_matched.extract(env)?; - let when_matched_update_expr: JString = env - .call_method(jparam, "whenMatchedUpdateExpr", "()Ljava/lang/String;", &[])? - .l()? - .into(); - let when_matched_update_expr = when_matched_update_expr.extract(env)?; + let when_matched_update_expr = env + .call_method( + jparam, + "whenMatchedUpdateExpr", + "()Ljava/util/Optional;", + &[], + )? + .l()?; + let when_matched_update_expr = env.get_string_opt(&when_matched_update_expr)?; match when_matched.as_str() { "UpdateAll" => Ok(WhenMatched::UpdateAll), "DoNothing" => Ok(WhenMatched::DoNothing), - "UpdateIf" => Ok(WhenMatched::UpdateIf(when_matched_update_expr)), + "UpdateIf" => match when_matched_update_expr { + Some(expr) => Ok(WhenMatched::UpdateIf(expr)), + None => Err(Error::input_error("No matched updated expr".to_string())), + }, _ => Err(Error::input_error(format!( "Illegal when_matched: {when_matched}", ))), @@ -153,45 +157,51 @@ fn extract_when_not_matched_by_source_str<'local>( fn extract_when_not_matched_by_source_delete_expr<'local>( env: &mut JNIEnv<'local>, jparam: &JObject, -) -> Result { - let when_not_matched_by_source_delete_expr: JString = env +) -> Result> { + let when_not_matched_by_source_delete_expr = env .call_method( jparam, "whenNotMatchedBySourceDeleteExpr", - "()Ljava/lang/String;", + "()Ljava/util/Optional;", &[], )? - .l()? - .into(); - when_not_matched_by_source_delete_expr.extract(env) + .l()?; + + if let Some(expr) = env.get_string_opt(&when_not_matched_by_source_delete_expr)? { + return Ok(Some(LanceFilter::Sql(expr))); + } + + let when_not_matched_by_source_delete_substrait_expr = env + .call_method( + jparam, + "whenNotMatchedBySourceDeleteSubstraitExpr", + "()Ljava/util/Optional;", + &[], + )? + .l()?; + + match env.get_bytes_opt(&when_not_matched_by_source_delete_substrait_expr)? { + Some(expr) => Ok(Some(LanceFilter::Substrait(expr.to_vec()))), + None => Ok(None), + } } fn extract_when_not_matched_by_source( - schema: Schema, + schema: &Schema, when_not_matched_by_source: &str, - when_not_matched_by_source_delete_expr: &str, + when_not_matched_by_source_delete_expr: Option, ) -> Result { match when_not_matched_by_source { "Keep" => Ok(WhenNotMatchedBySource::Keep), "Delete" => Ok(WhenNotMatchedBySource::Delete), - "DeleteIf" => { - let sql_expr = DFParserBuilder::new(when_not_matched_by_source_delete_expr) - .build() - .unwrap() - .parser - .parse_expr() - .unwrap(); - - let expr = SqlToRel::new(&LanceContextProvider::default()) - .sql_to_expr( - sql_expr, - &DFSchema::try_from(schema).unwrap(), - &mut PlannerContext::default(), - ) - .unwrap(); - - Ok(WhenNotMatchedBySource::DeleteIf(expr)) - } + "DeleteIf" => match when_not_matched_by_source_delete_expr { + Some(expr) => Ok(WhenNotMatchedBySource::DeleteIf( + expr.to_datafusion(schema, schema)?, + )), + None => Err(Error::input_error(format!( + "No delete expr when not matched by source is: {when_not_matched_by_source}", + ))), + }, _ => Err(Error::input_error(format!( "Illegal when_not_matched_by_source: {when_not_matched_by_source}", ))), diff --git a/java/src/main/java/com/lancedb/lance/Dataset.java b/java/src/main/java/com/lancedb/lance/Dataset.java index 149bf3f0df6..2dd508c5e1a 100644 --- a/java/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/src/main/java/com/lancedb/lance/Dataset.java @@ -18,7 +18,7 @@ import com.lancedb.lance.ipc.DataStatistics; import com.lancedb.lance.ipc.LanceScanner; import com.lancedb.lance.ipc.ScanOptions; -import com.lancedb.lance.merge.MergeInsert; +import com.lancedb.lance.merge.MergeInsertParams; import com.lancedb.lance.merge.MergeInsertResult; import com.lancedb.lance.schema.ColumnAlteration; import com.lancedb.lance.schema.LanceSchema; @@ -972,11 +972,11 @@ public SqlQuery sql(String sql) { *

It is important that after merge insert, the current dataset is changed and should be * closed. The merged new dataset is contained in the MergeInsertResult. * - * @param mergeInsert MergeInsert options + * @param mergeInsert merge insert options * @param source ArrowArrayStream source data * @return MergeInsertResult containing the new merged Dataset. */ - public MergeInsertResult mergeInsert(MergeInsert mergeInsert, ArrowArrayStream source) { + public MergeInsertResult mergeInsert(MergeInsertParams mergeInsert, ArrowArrayStream source) { try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { MergeInsertResult result = nativeMergeInsert(mergeInsert, source.memoryAddress()); @@ -987,12 +987,12 @@ public MergeInsertResult mergeInsert(MergeInsert mergeInsert, ArrowArrayStream s newDataset.allocator = allocator; } - return new MergeInsertResult(newDataset, result.stats()); + return result; } } private native MergeInsertResult nativeMergeInsert( - MergeInsert mergeInsert, long arrowStreamMemoryAddress); + MergeInsertParams mergeInsert, long arrowStreamMemoryAddress); private native void nativeCreateTag(String tag, long version); diff --git a/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java b/java/src/main/java/com/lancedb/lance/merge/MergeInsertParams.java similarity index 65% rename from java/src/main/java/com/lancedb/lance/merge/MergeInsert.java rename to java/src/main/java/com/lancedb/lance/merge/MergeInsertParams.java index 065ebe4ee3b..9f63a5e3baa 100644 --- a/java/src/main/java/com/lancedb/lance/merge/MergeInsert.java +++ b/java/src/main/java/com/lancedb/lance/merge/MergeInsertParams.java @@ -13,24 +13,30 @@ */ package com.lancedb.lance.merge; +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; + +import java.nio.ByteBuffer; import java.util.List; +import java.util.Optional; -public class MergeInsert { +public class MergeInsertParams { private final List on; private WhenMatched whenMatched = WhenMatched.DoNothing; - private String whenMatchedUpdateExpr = ""; + private Optional whenMatchedUpdateExpr = Optional.empty(); private WhenNotMatched whenNotMatched = WhenNotMatched.InsertAll; private WhenNotMatchedBySource whenNotMatchedBySource = WhenNotMatchedBySource.Keep; - private String whenNotMatchedBySourceDeleteExpr = ""; + private Optional whenNotMatchedBySourceDeleteExpr = Optional.empty(); + private Optional whenNotMatchedBySourceDeleteSubstraitExpr = Optional.empty(); private int conflictRetries = 10; private long retryTimeoutMs = 30 * 1000; private boolean skipAutoCleanup = false; - public MergeInsert(List on) { + public MergeInsertParams(List on) { this.on = on; } @@ -40,9 +46,9 @@ public MergeInsert(List on) { * *

This can be used to achieve upsert behavior. * - * @return This MergeInsert instance + * @return This MergeInsertParams instance */ - public MergeInsert withMatchedUpdateAll() { + public MergeInsertParams withMatchedUpdateAll() { this.whenMatched = WhenMatched.UpdateAll; return this; } @@ -53,9 +59,9 @@ public MergeInsert withMatchedUpdateAll() { * *

This can be used to achieve find-or-create behavior. * - * @return This MergeInsert instance + * @return This MergeInsertParams instance */ - public MergeInsert withMatchedDoNothing() { + public MergeInsertParams withMatchedDoNothing() { this.whenMatched = WhenMatched.DoNothing; return this; } @@ -72,11 +78,12 @@ public MergeInsert withMatchedDoNothing() { * source.column1 = target.column1 AND source.column2 = target.column2 * * @param expr The expression to evaluate on the rows in the source table and target table. - * @return This MergeInsert instance + * @return This MergeInsertParams instance */ - public MergeInsert withMatchedUpdateIf(String expr) { + public MergeInsertParams withMatchedUpdateIf(String expr) { + Preconditions.checkNotNull(expr); this.whenMatched = WhenMatched.UpdateIf; - this.whenMatchedUpdateExpr = expr; + this.whenMatchedUpdateExpr = Optional.of(expr); return this; } @@ -84,9 +91,9 @@ public MergeInsert withMatchedUpdateIf(String expr) { * Specify what should happen when a source row has no match in the target. * * @param whenNotMatched The action to take when a source row has no match in the target. - * @return This MergeInsert instance + * @return This MergeInsertParams instance */ - public MergeInsert withNotMatched(WhenNotMatched whenNotMatched) { + public MergeInsertParams withNotMatched(WhenNotMatched whenNotMatched) { this.whenNotMatched = whenNotMatched; return this; } @@ -94,9 +101,9 @@ public MergeInsert withNotMatched(WhenNotMatched whenNotMatched) { /** * Specify that when a target row has no match in the source, the row is kept in the target table. * - * @return This MergeInsert instance + * @return This MergeInsertParams instance */ - public MergeInsert withNotMatchedBySourceKeep() { + public MergeInsertParams withNotMatchedBySourceKeep() { this.whenNotMatchedBySource = WhenNotMatchedBySource.Keep; return this; } @@ -105,9 +112,9 @@ public MergeInsert withNotMatchedBySourceKeep() { * Specify that when a target row has no match in the source, the row is deleted from the target * table. * - * @return This MergeInsert instance + * @return This MergeInsertParams instance */ - public MergeInsert withNotMatchedBySourceDelete() { + public MergeInsertParams withNotMatchedBySourceDelete() { this.whenNotMatchedBySource = WhenNotMatchedBySource.Delete; return this; } @@ -116,12 +123,29 @@ public MergeInsert withNotMatchedBySourceDelete() { * Specify that when a target row has no match in the source and the expression evaluates to true, * the row is deleted from the target table. * - * @param expr The expression to evaluate on the rows in the target table. - * @return This MergeInsert instance + * @param expr The sql expression to evaluate on the rows in the target table. + * @return This MergeInsertParams instance + */ + public MergeInsertParams withNotMatchedBySourceDeleteIf(String expr) { + Preconditions.checkNotNull(expr); + this.whenNotMatchedBySource = WhenNotMatchedBySource.DeleteIf; + this.whenNotMatchedBySourceDeleteExpr = Optional.of(expr); + this.whenNotMatchedBySourceDeleteSubstraitExpr = Optional.empty(); + return this; + } + + /** + * Specify that when a target row has no match in the source and the expression evaluates to true, + * the row is deleted from the target table. + * + * @param expr The substrait expression to evaluate on the rows in the target table. + * @return This MergeInsertParams instance */ - public MergeInsert withNotMatchedBySourceDeleteIf(String expr) { + public MergeInsertParams withNotMatchedBySourceDeleteSubstraitIf(ByteBuffer expr) { + Preconditions.checkNotNull(expr); this.whenNotMatchedBySource = WhenNotMatchedBySource.DeleteIf; - this.whenNotMatchedBySourceDeleteExpr = expr; + this.whenNotMatchedBySourceDeleteExpr = Optional.empty(); + this.whenNotMatchedBySourceDeleteSubstraitExpr = Optional.of(expr); return this; } @@ -135,9 +159,9 @@ public MergeInsert withNotMatchedBySourceDeleteIf(String expr) { *

Default is 10. * * @param retries Number of times to retry the operation if there is contention. - * @return This MergeInsert instance + * @return This MergeInsertParams instance */ - public MergeInsert withConflictRetries(int retries) { + public MergeInsertParams withConflictRetries(int retries) { this.conflictRetries = retries; return this; } @@ -153,9 +177,9 @@ public MergeInsert withConflictRetries(int retries) { *

Default is 30000. * * @param timeoutMs Timeout in milliseconds used to limit retries. - * @return This MergeInsert instance + * @return This MergeInsertParams instance */ - public MergeInsert withRetryTimeoutMs(long timeoutMs) { + public MergeInsertParams withRetryTimeoutMs(long timeoutMs) { this.retryTimeoutMs = timeoutMs; return this; } @@ -166,9 +190,9 @@ public MergeInsert withRetryTimeoutMs(long timeoutMs) { * the clean up would just try and log a failure anyway. * * @param skipAutoCleanup Whether to skip auto cleanup during commits. - * @return This MergeInsert instance + * @return This MergeInsertParams instance */ - public MergeInsert withSkipAutoCleanup(boolean skipAutoCleanup) { + public MergeInsertParams withSkipAutoCleanup(boolean skipAutoCleanup) { this.skipAutoCleanup = skipAutoCleanup; return this; } @@ -185,7 +209,7 @@ public String whenMatchedValue() { return whenMatched.name(); } - public String whenMatchedUpdateExpr() { + public Optional whenMatchedUpdateExpr() { return whenMatchedUpdateExpr; } @@ -205,10 +229,14 @@ public String whenNotMatchedBySourceValue() { return whenNotMatchedBySource.name(); } - public String whenNotMatchedBySourceDeleteExpr() { + public Optional whenNotMatchedBySourceDeleteExpr() { return whenNotMatchedBySourceDeleteExpr; } + public Optional whenNotMatchedBySourceDeleteSubstraitExpr() { + return whenNotMatchedBySourceDeleteSubstraitExpr; + } + public int conflictRetries() { return conflictRetries; } @@ -221,6 +249,26 @@ public boolean skipAutoCleanup() { return skipAutoCleanup; } + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("on", on) + .add("whenMatched", whenMatched) + .add("whenMatchedUpdateExpr", whenMatchedUpdateExpr.orElse(null)) + .add("whenNotMatched", whenNotMatched) + .add("whenNotMatchedBySource", whenNotMatchedBySource) + .add("whenNotMatchedBySourceDeleteExpr", whenNotMatchedBySourceDeleteExpr.orElse(null)) + .add( + "whenNotMatchedBySourceDeleteSubstraitExpr", + whenNotMatchedBySourceDeleteSubstraitExpr + .map(buf -> "ByteBuffer[" + buf.remaining() + " bytes]") + .orElse(null)) + .add("conflictRetries", conflictRetries) + .add("retryTimeoutMs", retryTimeoutMs) + .add("skipAutoCleanup", skipAutoCleanup) + .toString(); + } + public enum WhenMatched { /** * The row is deleted from the target table and a new row is inserted based on the source table. diff --git a/java/src/main/java/com/lancedb/lance/merge/MergeInsertStats.java b/java/src/main/java/com/lancedb/lance/merge/MergeInsertStats.java index 58b8c15574a..45a47742d3b 100644 --- a/java/src/main/java/com/lancedb/lance/merge/MergeInsertStats.java +++ b/java/src/main/java/com/lancedb/lance/merge/MergeInsertStats.java @@ -13,6 +13,8 @@ */ package com.lancedb.lance.merge; +import com.google.common.base.MoreObjects; + public final class MergeInsertStats { private final long numInsertedRows; private final long numUpdatedRows; @@ -62,19 +64,13 @@ public long numFilesWritten() { @Override public String toString() { - return "MergeInsertStats{" - + "numInsertedRows=" - + numInsertedRows - + ", numUpdatedRows=" - + numUpdatedRows - + ", numDeletedRows=" - + numDeletedRows - + ", numAttempts=" - + numAttempts - + ", bytesWritten=" - + bytesWritten - + ", numFilesWritten=" - + numFilesWritten - + '}'; + return MoreObjects.toStringHelper(this) + .add("numInsertedRows", numInsertedRows) + .add("numUpdatedRows", numUpdatedRows) + .add("numDeletedRows", numDeletedRows) + .add("numAttempts", numAttempts) + .add("bytesWritten", bytesWritten) + .add("numFilesWritten", numFilesWritten) + .toString(); } } diff --git a/java/src/test/java/com/lancedb/lance/MergeInsertTest.java b/java/src/test/java/com/lancedb/lance/MergeInsertTest.java new file mode 100644 index 00000000000..71b6d41cd20 --- /dev/null +++ b/java/src/test/java/com/lancedb/lance/MergeInsertTest.java @@ -0,0 +1,217 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance; + +import com.lancedb.lance.merge.MergeInsertParams; +import com.lancedb.lance.merge.MergeInsertResult; +import com.lancedb.lance.operation.OperationTestBase; + +import org.apache.arrow.c.ArrowArrayStream; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.TreeMap; +import java.util.UUID; + +public class MergeInsertTest extends OperationTestBase { + @TempDir private Path tempDir; + private RootAllocator allocator; + private TestUtils.SimpleTestDataset testDataset; + private Dataset dataset; + + @BeforeEach + public void setup() { + String datasetPath = tempDir.resolve(UUID.randomUUID().toString()).toString(); + allocator = new RootAllocator(Long.MAX_VALUE); + testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + dataset = testDataset.write(1, 5); + } + + @AfterEach + public void tearDown() { + dataset.close(); + allocator.close(); + } + + @Test + public void testWhenNotMatchedInsertAll() throws Exception { + // Test insert all unmatched source rows + + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + dataset.mergeInsert( + new MergeInsertParams(Collections.singletonList("id")), sourceStream); + + Assertions.assertEquals( + "{0=Person 0, 1=Person 1, 2=Person 2, 3=Person 3, 4=Person 4, 7=Source 7, 8=Source 8, 9=Source 9}", + readAll(result.dataset()).toString()); + } + } + } + + @Test + public void testWhenNotMatchedDoNothing() throws Exception { + // Test ignore unmatched source rows + + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + dataset.mergeInsert( + new MergeInsertParams(Collections.singletonList("id")) + .withMatchedUpdateAll() + .withNotMatched(MergeInsertParams.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Source 0, 1=Source 1, 2=Source 2, 3=Person 3, 4=Person 4}", + readAll(result.dataset()).toString()); + } + } + } + + @Test + public void testWhenMatchedUpdateIf() throws Exception { + // Test update matched rows if expression is true + + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + dataset.mergeInsert( + new MergeInsertParams(Collections.singletonList("id")) + .withMatchedUpdateIf("target.name = 'Person 0' or target.name = 'Person 1'") + .withNotMatched(MergeInsertParams.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Source 0, 1=Source 1, 2=Person 2, 3=Person 3, 4=Person 4}", + readAll(result.dataset()).toString()); + } + } + } + + @Test + public void testWhenNotMatchedBySourceDelete() throws Exception { + // Test delete target rows which are not matched with source. + + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + dataset.mergeInsert( + new MergeInsertParams(Collections.singletonList("id")) + .withNotMatchedBySourceDelete() + .withNotMatched(MergeInsertParams.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Person 0, 1=Person 1, 2=Person 2}", readAll(result.dataset()).toString()); + } + } + } + + @Test + public void testWhenNotMatchedBySourceDeleteIf() throws Exception { + // Test delete target rows which are not matched with source if expression is true + + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + dataset.mergeInsert( + new MergeInsertParams(Collections.singletonList("id")) + .withNotMatchedBySourceDeleteIf("name = 'Person 3'") + .withNotMatched(MergeInsertParams.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Person 0, 1=Person 1, 2=Person 2, 4=Person 4}", + readAll(result.dataset()).toString()); + } + } + } + + private VectorSchemaRoot buildSource(Schema schema, RootAllocator allocator) { + List sourceIds = Arrays.asList(0, 1, 2, 7, 8, 9); + + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + root.allocateNew(); + + IntVector idVector = (IntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + + for (int i = 0; i < sourceIds.size(); i++) { + idVector.setSafe(i, sourceIds.get(i)); + String name = "Source " + sourceIds.get(i); + nameVector.setSafe(i, name.getBytes(StandardCharsets.UTF_8)); + } + + root.setRowCount(sourceIds.size()); + + return root; + } + + private ArrowArrayStream convertToStream(VectorSchemaRoot root, RootAllocator allocator) + throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out)) { + writer.start(); + writer.writeBatch(); + writer.end(); + } + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + ArrowStreamReader reader = new ArrowStreamReader(in, allocator); + + ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator); + Data.exportArrayStream(allocator, reader, stream); + + return stream; + } + + private TreeMap readAll(Dataset dataset) throws Exception { + try (ArrowReader reader = dataset.newScan().scanBatches()) { + TreeMap map = new TreeMap<>(); + + while (reader.loadNextBatch()) { + VectorSchemaRoot batch = reader.getVectorSchemaRoot(); + for (int i = 0; i < batch.getRowCount(); i++) { + IntVector idVector = (IntVector) batch.getVector("id"); + VarCharVector nameVector = (VarCharVector) batch.getVector("name"); + map.put(idVector.get(i), new String(nameVector.get(i))); + } + } + + return map; + } + } +} diff --git a/java/src/test/java/com/lancedb/lance/operation/MergeInsertTest.java b/java/src/test/java/com/lancedb/lance/operation/MergeInsertTest.java deleted file mode 100644 index 12eed70c931..00000000000 --- a/java/src/test/java/com/lancedb/lance/operation/MergeInsertTest.java +++ /dev/null @@ -1,244 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.lancedb.lance.operation; - -import com.lancedb.lance.Dataset; -import com.lancedb.lance.TestUtils; -import com.lancedb.lance.merge.MergeInsert; -import com.lancedb.lance.merge.MergeInsertResult; - -import org.apache.arrow.c.ArrowArrayStream; -import org.apache.arrow.c.Data; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.VarCharVector; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.ipc.ArrowReader; -import org.apache.arrow.vector.ipc.ArrowStreamReader; -import org.apache.arrow.vector.ipc.ArrowStreamWriter; -import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.TempDir; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.nio.charset.StandardCharsets; -import java.nio.file.Path; -import java.util.Arrays; -import java.util.List; -import java.util.TreeMap; - -public class MergeInsertTest extends OperationTestBase { - - @Test - public void testWhenNotMatchedInsertAll(@TempDir Path tempDir) throws Exception { - // Test insert all unmatched source rows - - String datasetPath = tempDir.resolve("testWhenNotMatchedInsertAll").toString(); - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - TestUtils.SimpleTestDataset testDataset = - new TestUtils.SimpleTestDataset(allocator, datasetPath); - - int rowCount = 5; - try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { - - try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { - try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { - MergeInsertResult result = - initialDataset.mergeInsert(new MergeInsert(Arrays.asList("id")), sourceStream); - - Assertions.assertEquals( - "{0=Person 0, 1=Person 1, 2=Person 2, 3=Person 3, 4=Person 4, 7=Source 7, 8=Source 8, 9=Source 9}", - readAll(result.dataset()).toString()); - } - } - } - } - } - - @Test - public void testWhenNotMatchedDoNothing(@TempDir Path tempDir) throws Exception { - // Test ignore unmatched source rows - - String datasetPath = tempDir.resolve("testWhenNotMatchedDoNothing").toString(); - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - TestUtils.SimpleTestDataset testDataset = - new TestUtils.SimpleTestDataset(allocator, datasetPath); - - int rowCount = 5; - try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { - - try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { - try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { - MergeInsertResult result = - initialDataset.mergeInsert( - new MergeInsert(Arrays.asList("id")) - .withMatchedUpdateAll() - .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), - sourceStream); - - Assertions.assertEquals( - "{0=Source 0, 1=Source 1, 2=Source 2, 3=Person 3, 4=Person 4}", - readAll(result.dataset()).toString()); - } - } - } - } - } - - @Test - public void testWhenMatchedUpdateIf(@TempDir Path tempDir) throws Exception { - // Test update matched rows if expression is true - - String datasetPath = tempDir.resolve("testWhenMatchedUpdateIf").toString(); - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - TestUtils.SimpleTestDataset testDataset = - new TestUtils.SimpleTestDataset(allocator, datasetPath); - - int rowCount = 5; - try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { - - try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { - try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { - MergeInsertResult result = - initialDataset.mergeInsert( - new MergeInsert(Arrays.asList("id")) - .withMatchedUpdateIf("target.name = 'Person 0' or target.name = 'Person 1'") - .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), - sourceStream); - - Assertions.assertEquals( - "{0=Source 0, 1=Source 1, 2=Person 2, 3=Person 3, 4=Person 4}", - readAll(result.dataset()).toString()); - } - } - } - } - } - - @Test - public void testWhenNotMatchedBySourceDelete(@TempDir Path tempDir) throws Exception { - // Test delete target rows which are not matched with source. - - String datasetPath = tempDir.resolve("testWhenNotMatchedBySourceDelete").toString(); - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - TestUtils.SimpleTestDataset testDataset = - new TestUtils.SimpleTestDataset(allocator, datasetPath); - - int rowCount = 5; - try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { - - try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { - try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { - MergeInsertResult result = - initialDataset.mergeInsert( - new MergeInsert(Arrays.asList("id")) - .withNotMatchedBySourceDelete() - .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), - sourceStream); - - Assertions.assertEquals( - "{0=Person 0, 1=Person 1, 2=Person 2}", readAll(result.dataset()).toString()); - } - } - } - } - } - - @Test - public void testWhenNotMatchedBySourceDeleteIf(@TempDir Path tempDir) throws Exception { - // Test delete target rows which are not matched with source if expression is true - - String datasetPath = tempDir.resolve("testWhenNotMatchedBySourceDeleteIf").toString(); - try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - TestUtils.SimpleTestDataset testDataset = - new TestUtils.SimpleTestDataset(allocator, datasetPath); - - int rowCount = 5; - try (Dataset initialDataset = createAndAppendRows(testDataset, rowCount)) { - - try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { - try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { - MergeInsertResult result = - initialDataset.mergeInsert( - new MergeInsert(Arrays.asList("id")) - .withNotMatchedBySourceDeleteIf("name = 'Person 3'") - .withNotMatched(MergeInsert.WhenNotMatched.DoNothing), - sourceStream); - - Assertions.assertEquals( - "{0=Person 0, 1=Person 1, 2=Person 2, 4=Person 4}", - readAll(result.dataset()).toString()); - } - } - } - } - } - - private VectorSchemaRoot buildSource(Schema schema, RootAllocator allocator) throws Exception { - List sourceIds = Arrays.asList(0, 1, 2, 7, 8, 9); - - VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); - root.allocateNew(); - - IntVector idVector = (IntVector) root.getVector("id"); - VarCharVector nameVector = (VarCharVector) root.getVector("name"); - - for (int i = 0; i < sourceIds.size(); i++) { - idVector.setSafe(i, sourceIds.get(i)); - String name = "Source " + sourceIds.get(i); - nameVector.setSafe(i, name.getBytes(StandardCharsets.UTF_8)); - } - - root.setRowCount(sourceIds.size()); - - return root; - } - - private ArrowArrayStream convertToStream(VectorSchemaRoot root, RootAllocator allocator) - throws Exception { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - try (ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out)) { - writer.start(); - writer.writeBatch(); - writer.end(); - } - - ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); - ArrowStreamReader reader = new ArrowStreamReader(in, allocator); - - ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator); - Data.exportArrayStream(allocator, reader, stream); - - return stream; - } - - private TreeMap readAll(Dataset dataset) throws Exception { - try (ArrowReader reader = dataset.newScan().scanBatches()) { - TreeMap map = new TreeMap<>(); - - while (reader.loadNextBatch()) { - VectorSchemaRoot batch = reader.getVectorSchemaRoot(); - for (int i = 0; i < batch.getRowCount(); i++) { - IntVector idVector = (IntVector) batch.getVector("id"); - VarCharVector nameVector = (VarCharVector) batch.getVector("name"); - map.put(idVector.get(i), new String(nameVector.get(i))); - } - } - - return map; - } - } -} From 1355d0c0a7349e4126ddb6c729991550a3114304 Mon Sep 17 00:00:00 2001 From: "fangbo.0511" Date: Thu, 11 Sep 2025 14:28:34 +0800 Subject: [PATCH 7/9] Minor modification --- java/src/test/java/com/lancedb/lance/MergeInsertTest.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/java/src/test/java/com/lancedb/lance/MergeInsertTest.java b/java/src/test/java/com/lancedb/lance/MergeInsertTest.java index 71b6d41cd20..58dd41f8a46 100644 --- a/java/src/test/java/com/lancedb/lance/MergeInsertTest.java +++ b/java/src/test/java/com/lancedb/lance/MergeInsertTest.java @@ -15,8 +15,6 @@ import com.lancedb.lance.merge.MergeInsertParams; import com.lancedb.lance.merge.MergeInsertResult; -import com.lancedb.lance.operation.OperationTestBase; - import org.apache.arrow.c.ArrowArrayStream; import org.apache.arrow.c.Data; import org.apache.arrow.memory.RootAllocator; @@ -43,7 +41,7 @@ import java.util.TreeMap; import java.util.UUID; -public class MergeInsertTest extends OperationTestBase { +public class MergeInsertTest { @TempDir private Path tempDir; private RootAllocator allocator; private TestUtils.SimpleTestDataset testDataset; From bf57354b57e018e59432d0eb90e36aa6359942ad Mon Sep 17 00:00:00 2001 From: "fangbo.0511" Date: Thu, 11 Sep 2025 14:35:20 +0800 Subject: [PATCH 8/9] Fix checkstyle issue --- java/src/test/java/com/lancedb/lance/MergeInsertTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/java/src/test/java/com/lancedb/lance/MergeInsertTest.java b/java/src/test/java/com/lancedb/lance/MergeInsertTest.java index 58dd41f8a46..2e9227260f2 100644 --- a/java/src/test/java/com/lancedb/lance/MergeInsertTest.java +++ b/java/src/test/java/com/lancedb/lance/MergeInsertTest.java @@ -15,6 +15,7 @@ import com.lancedb.lance.merge.MergeInsertParams; import com.lancedb.lance.merge.MergeInsertResult; + import org.apache.arrow.c.ArrowArrayStream; import org.apache.arrow.c.Data; import org.apache.arrow.memory.RootAllocator; From b19ccb78c491e3759b2b04f91de4c9472d9c0afb Mon Sep 17 00:00:00 2001 From: "fangbo.0511" Date: Fri, 12 Sep 2025 10:10:01 +0800 Subject: [PATCH 9/9] revert unnecessary modification --- rust/lance-datafusion/src/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/lance-datafusion/src/planner.rs b/rust/lance-datafusion/src/planner.rs index 43f88fe03b2..ddbce5e5b85 100644 --- a/rust/lance-datafusion/src/planner.rs +++ b/rust/lance-datafusion/src/planner.rs @@ -152,7 +152,7 @@ impl ScalarUDFImpl for CastListF16Udf { } // Adapter that instructs datafusion how lance expects expressions to be interpreted -pub struct LanceContextProvider { +struct LanceContextProvider { options: datafusion::config::ConfigOptions, state: SessionState, expr_planners: Vec>,