diff --git a/java/lance-jni/src/lib.rs b/java/lance-jni/src/lib.rs index 689b2e3e41f..5f913bae896 100644 --- a/java/lance-jni/src/lib.rs +++ b/java/lance-jni/src/lib.rs @@ -46,6 +46,7 @@ pub mod ffi; mod file_reader; mod file_writer; mod fragment; +mod merge_insert; mod schema; mod sql; pub mod traits; diff --git a/java/lance-jni/src/merge_insert.rs b/java/lance-jni/src/merge_insert.rs new file mode 100644 index 00000000000..b1ff490444f --- /dev/null +++ b/java/lance-jni/src/merge_insert.rs @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use crate::blocking_dataset::{BlockingDataset, NATIVE_DATASET}; +use crate::error::Result; +use crate::traits::{FromJString, IntoJava}; +use crate::{Error, JNIEnvExt, RT}; +use arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream}; +use jni::objects::{JObject, JString, JValueGen}; +use jni::sys::jlong; +use jni::JNIEnv; +use lance::dataset::scanner::LanceFilter; +use lance::dataset::{ + MergeInsertBuilder, MergeStats, WhenMatched, WhenNotMatched, WhenNotMatchedBySource, +}; +use lance_core::datatypes::Schema; +use std::sync::Arc; +use std::time::Duration; + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeMergeInsert<'a>( + mut env: JNIEnv<'a>, + jdataset: JObject, // Dataset object + jparam: JObject, // MergeInsertParams object + batch_address: jlong, // ArrowArrayStream address for source +) -> JObject<'a> { + ok_or_throw!( + env, + inner_merge_insert(&mut env, jdataset, jparam, batch_address) + ) +} + +#[allow(clippy::too_many_arguments)] +fn inner_merge_insert<'local>( + env: &mut JNIEnv<'local>, + jdataset: JObject, + jparam: JObject, + batch_address: jlong, +) -> Result> { + let on = extract_on(env, &jparam)?; + let when_matched = extract_when_matched(env, &jparam)?; + let when_not_matched = extract_when_not_matached(env, &jparam)?; + + let when_not_matched_by_source_str = extract_when_not_matched_by_source_str(env, 
&jparam)?; + let when_not_matched_by_source_delete_expr = + extract_when_not_matched_by_source_delete_expr(env, &jparam)?; + + let conflict_retries = extract_conflict_retries(env, &jparam)?; + let retry_timeout_ms = extract_retry_timeout_ms(env, &jparam)?; + let skip_auto_cleanup = extract_skip_auto_cleanup(env, &jparam)?; + + let (new_ds, merge_stats) = unsafe { + let dataset = env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET)?; + + let when_not_matched_by_source = extract_when_not_matched_by_source( + dataset.inner.schema(), + when_not_matched_by_source_str.as_str(), + when_not_matched_by_source_delete_expr, + )?; + + let merge_insert_job = MergeInsertBuilder::try_new(Arc::new(dataset.clone().inner), on)? + .when_matched(when_matched) + .when_not_matched(when_not_matched) + .when_not_matched_by_source(when_not_matched_by_source) + .conflict_retries(conflict_retries) + .retry_timeout(Duration::from_millis(retry_timeout_ms as u64)) + .skip_auto_cleanup(skip_auto_cleanup) + .try_build()?; + + let stream_ptr = batch_address as *mut FFI_ArrowArrayStream; + let source_stream = ArrowArrayStreamReader::from_raw(stream_ptr)?; + + RT.block_on(async move { merge_insert_job.execute_reader(source_stream).await })? + }; + + MergeResult( + BlockingDataset { + inner: Arc::try_unwrap(new_ds).unwrap(), + }, + merge_stats, + ) + .into_java(env) +} + +fn extract_on<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result> { + let on: JObject = env + .call_method(jparam, "on", "()Ljava/util/List;", &[])? + .l()?; + env.get_strings(&on) +} + +fn extract_when_matched<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result { + let when_matched: JString = env + .call_method(jparam, "whenMatchedValue", "()Ljava/lang/String;", &[])? + .l()? + .into(); + let when_matched = when_matched.extract(env)?; + + let when_matched_update_expr = env + .call_method( + jparam, + "whenMatchedUpdateExpr", + "()Ljava/util/Optional;", + &[], + )? 
+ .l()?; + let when_matched_update_expr = env.get_string_opt(&when_matched_update_expr)?; + + match when_matched.as_str() { + "UpdateAll" => Ok(WhenMatched::UpdateAll), + "DoNothing" => Ok(WhenMatched::DoNothing), + "UpdateIf" => match when_matched_update_expr { + Some(expr) => Ok(WhenMatched::UpdateIf(expr)), + None => Err(Error::input_error("No matched updated expr".to_string())), + }, + _ => Err(Error::input_error(format!( + "Illegal when_matched: {when_matched}", + ))), + } +} + +fn extract_when_not_matached<'local>( + env: &mut JNIEnv<'local>, + jparam: &JObject, +) -> Result { + let when_not_matched: JString = env + .call_method(jparam, "whenNotMatchedValue", "()Ljava/lang/String;", &[])? + .l()? + .into(); + let when_not_matched = when_not_matched.extract(env)?; + + match when_not_matched.as_str() { + "InsertAll" => Ok(WhenNotMatched::InsertAll), + "DoNothing" => Ok(WhenNotMatched::DoNothing), + _ => Err(Error::input_error(format!( + "Illegal when_not_matched: {when_not_matched}", + ))), + } +} + +fn extract_when_not_matched_by_source_str<'local>( + env: &mut JNIEnv<'local>, + jparam: &JObject, +) -> Result { + let when_not_matched_by_source: JString = env + .call_method( + jparam, + "whenNotMatchedBySourceValue", + "()Ljava/lang/String;", + &[], + )? + .l()? + .into(); + when_not_matched_by_source.extract(env) +} + +fn extract_when_not_matched_by_source_delete_expr<'local>( + env: &mut JNIEnv<'local>, + jparam: &JObject, +) -> Result> { + let when_not_matched_by_source_delete_expr = env + .call_method( + jparam, + "whenNotMatchedBySourceDeleteExpr", + "()Ljava/util/Optional;", + &[], + )? + .l()?; + + if let Some(expr) = env.get_string_opt(&when_not_matched_by_source_delete_expr)? { + return Ok(Some(LanceFilter::Sql(expr))); + } + + let when_not_matched_by_source_delete_substrait_expr = env + .call_method( + jparam, + "whenNotMatchedBySourceDeleteSubstraitExpr", + "()Ljava/util/Optional;", + &[], + )? 
+ .l()?; + + match env.get_bytes_opt(&when_not_matched_by_source_delete_substrait_expr)? { + Some(expr) => Ok(Some(LanceFilter::Substrait(expr.to_vec()))), + None => Ok(None), + } +} + +fn extract_when_not_matched_by_source( + schema: &Schema, + when_not_matched_by_source: &str, + when_not_matched_by_source_delete_expr: Option, +) -> Result { + match when_not_matched_by_source { + "Keep" => Ok(WhenNotMatchedBySource::Keep), + "Delete" => Ok(WhenNotMatchedBySource::Delete), + "DeleteIf" => match when_not_matched_by_source_delete_expr { + Some(expr) => Ok(WhenNotMatchedBySource::DeleteIf( + expr.to_datafusion(schema, schema)?, + )), + None => Err(Error::input_error(format!( + "No delete expr when not matched by source is: {when_not_matched_by_source}", + ))), + }, + _ => Err(Error::input_error(format!( + "Illegal when_not_matched_by_source: {when_not_matched_by_source}", + ))), + } +} + +fn extract_conflict_retries<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result { + let retries = env + .call_method(jparam, "conflictRetries", "()I", &[])? + .i()? as u32; + Ok(retries) +} + +fn extract_retry_timeout_ms<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result { + let timeout_ms = env.call_method(jparam, "retryTimeoutMs", "()J", &[])?.j()? as u64; + Ok(timeout_ms) +} + +fn extract_skip_auto_cleanup<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result { + let skip_auto_cleanup = env + .call_method(jparam, "skipAutoCleanup", "()Z", &[])? 
+ .z()?; + Ok(skip_auto_cleanup) +} + +const MERGE_STATS_CLASS: &str = "com/lancedb/lance/merge/MergeInsertStats"; +const MERGE_STATS_CONSTRUCTOR_SIG: &str = "(JJJIJJ)V"; +const MERGE_RESULT_CLASS: &str = "com/lancedb/lance/merge/MergeInsertResult"; +const MERGE_RESULT_CONSTRUCTOR_SIG: &str = + "(Lcom/lancedb/lance/Dataset;Lcom/lancedb/lance/merge/MergeInsertStats;)V"; + +impl IntoJava for MergeStats { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + Ok(env.new_object( + MERGE_STATS_CLASS, + MERGE_STATS_CONSTRUCTOR_SIG, + &[ + JValueGen::Long(self.num_inserted_rows as i64), + JValueGen::Long(self.num_updated_rows as i64), + JValueGen::Long(self.num_deleted_rows as i64), + JValueGen::Int(self.num_attempts as i32), + JValueGen::Long(self.bytes_written as i64), + JValueGen::Long(self.num_files_written as i64), + ], + )?) + } +} + +struct MergeResult(BlockingDataset, MergeStats); + +impl IntoJava for MergeResult { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + let jdataset = self.0.into_java(env)?; + let jstats = self.1.into_java(env)?; + Ok(env.new_object( + MERGE_RESULT_CLASS, + MERGE_RESULT_CONSTRUCTOR_SIG, + &[JValueGen::Object(&jdataset), JValueGen::Object(&jstats)], + )?) 
+ } +} diff --git a/java/src/main/java/com/lancedb/lance/Dataset.java b/java/src/main/java/com/lancedb/lance/Dataset.java index 423e1b234be..2dd508c5e1a 100644 --- a/java/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/src/main/java/com/lancedb/lance/Dataset.java @@ -18,6 +18,8 @@ import com.lancedb.lance.ipc.DataStatistics; import com.lancedb.lance.ipc.LanceScanner; import com.lancedb.lance.ipc.ScanOptions; +import com.lancedb.lance.merge.MergeInsertParams; +import com.lancedb.lance.merge.MergeInsertResult; import com.lancedb.lance.schema.ColumnAlteration; import com.lancedb.lance.schema.LanceSchema; import com.lancedb.lance.schema.SqlExpressions; @@ -961,6 +963,37 @@ public SqlQuery sql(String sql) { return new SqlQuery(this, sql); } + /** + * Merge source data with the existing target data. + * + *

+   * <p>This will take in the source, merge it with the existing target data, and insert new rows,
+   * update existing rows, and delete existing rows.
+   *

It is important that after merge insert, the current dataset is changed and should be + * closed. The merged new dataset is contained in the MergeInsertResult. + * + * @param mergeInsert merge insert options + * @param source ArrowArrayStream source data + * @return MergeInsertResult containing the new merged Dataset. + */ + public MergeInsertResult mergeInsert(MergeInsertParams mergeInsert, ArrowArrayStream source) { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + MergeInsertResult result = nativeMergeInsert(mergeInsert, source.memoryAddress()); + + Dataset newDataset = result.dataset(); + if (selfManagedAllocator) { + newDataset.allocator = new RootAllocator(Long.MAX_VALUE); + } else { + newDataset.allocator = allocator; + } + + return result; + } + } + + private native MergeInsertResult nativeMergeInsert( + MergeInsertParams mergeInsert, long arrowStreamMemoryAddress); + private native void nativeCreateTag(String tag, long version); private native void nativeDeleteTag(String tag); diff --git a/java/src/main/java/com/lancedb/lance/merge/MergeInsertParams.java b/java/src/main/java/com/lancedb/lance/merge/MergeInsertParams.java new file mode 100644 index 00000000000..9f63a5e3baa --- /dev/null +++ b/java/src/main/java/com/lancedb/lance/merge/MergeInsertParams.java @@ -0,0 +1,316 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.lancedb.lance.merge; + +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Optional; + +public class MergeInsertParams { + private final List on; + + private WhenMatched whenMatched = WhenMatched.DoNothing; + private Optional whenMatchedUpdateExpr = Optional.empty(); + + private WhenNotMatched whenNotMatched = WhenNotMatched.InsertAll; + + private WhenNotMatchedBySource whenNotMatchedBySource = WhenNotMatchedBySource.Keep; + private Optional whenNotMatchedBySourceDeleteExpr = Optional.empty(); + private Optional whenNotMatchedBySourceDeleteSubstraitExpr = Optional.empty(); + + private int conflictRetries = 10; + private long retryTimeoutMs = 30 * 1000; + private boolean skipAutoCleanup = false; + + public MergeInsertParams(List on) { + this.on = on; + } + + /** + * Specify that when a row in the source table matches a row in the target table, the row is + * deleted from the target table and the matched row based on the source table is inserted. + * + *

This can be used to achieve upsert behavior. + * + * @return This MergeInsertParams instance + */ + public MergeInsertParams withMatchedUpdateAll() { + this.whenMatched = WhenMatched.UpdateAll; + return this; + } + + /** + * Specify that when a row in the source table matches a row in the target table, the row in the + * target table is kept unchanged. + * + *

This can be used to achieve find-or-create behavior. + * + * @return This MergeInsertParams instance + */ + public MergeInsertParams withMatchedDoNothing() { + this.whenMatched = WhenMatched.DoNothing; + return this; + } + + /** + * Specify that when a row in the source table matches a row in the target table and the + * expression evaluates to true, the row in the target table is updated by the matched row from + * the source table. + * + *

This can be used to achieve upsert behavior. + * + *

The expression can reference source tables' columns with source. and target + * tables' columns with target. This is an example: + * source.column1 = target.column1 AND source.column2 = target.column2 + * + * @param expr The expression to evaluate on the rows in the source table and target table. + * @return This MergeInsertParams instance + */ + public MergeInsertParams withMatchedUpdateIf(String expr) { + Preconditions.checkNotNull(expr); + this.whenMatched = WhenMatched.UpdateIf; + this.whenMatchedUpdateExpr = Optional.of(expr); + return this; + } + + /** + * Specify what should happen when a source row has no match in the target. + * + * @param whenNotMatched The action to take when a source row has no match in the target. + * @return This MergeInsertParams instance + */ + public MergeInsertParams withNotMatched(WhenNotMatched whenNotMatched) { + this.whenNotMatched = whenNotMatched; + return this; + } + + /** + * Specify that when a target row has no match in the source, the row is kept in the target table. + * + * @return This MergeInsertParams instance + */ + public MergeInsertParams withNotMatchedBySourceKeep() { + this.whenNotMatchedBySource = WhenNotMatchedBySource.Keep; + return this; + } + + /** + * Specify that when a target row has no match in the source, the row is deleted from the target + * table. + * + * @return This MergeInsertParams instance + */ + public MergeInsertParams withNotMatchedBySourceDelete() { + this.whenNotMatchedBySource = WhenNotMatchedBySource.Delete; + return this; + } + + /** + * Specify that when a target row has no match in the source and the expression evaluates to true, + * the row is deleted from the target table. + * + * @param expr The sql expression to evaluate on the rows in the target table. 
+ * @return This MergeInsertParams instance + */ + public MergeInsertParams withNotMatchedBySourceDeleteIf(String expr) { + Preconditions.checkNotNull(expr); + this.whenNotMatchedBySource = WhenNotMatchedBySource.DeleteIf; + this.whenNotMatchedBySourceDeleteExpr = Optional.of(expr); + this.whenNotMatchedBySourceDeleteSubstraitExpr = Optional.empty(); + return this; + } + + /** + * Specify that when a target row has no match in the source and the expression evaluates to true, + * the row is deleted from the target table. + * + * @param expr The substrait expression to evaluate on the rows in the target table. + * @return This MergeInsertParams instance + */ + public MergeInsertParams withNotMatchedBySourceDeleteSubstraitIf(ByteBuffer expr) { + Preconditions.checkNotNull(expr); + this.whenNotMatchedBySource = WhenNotMatchedBySource.DeleteIf; + this.whenNotMatchedBySourceDeleteExpr = Optional.empty(); + this.whenNotMatchedBySourceDeleteSubstraitExpr = Optional.of(expr); + return this; + } + + /** + * Set number of times to retry the operation if there is contention. + * + *

If this is set greater than 0, then the operation will keep a copy of the input data either + * in memory or on disk (depending on the size of the data) and will retry the operation if there + * is contention. + * + *

Default is 10. + * + * @param retries Number of times to retry the operation if there is contention. + * @return This MergeInsertParams instance + */ + public MergeInsertParams withConflictRetries(int retries) { + this.conflictRetries = retries; + return this; + } + + /** + * Set the timeout in milliseconds used to limit retries. + * + *

This is the maximum time to spend on the operation before giving up. At least one attempt + * will be made, regardless of how long it takes to complete. Subsequent attempts will be + * cancelled once this timeout is reached. If the timeout has been reached during the first + * attempt, the operation will be cancelled immediately. + * + *

Default is 30000. + * + * @param timeoutMs Timeout in milliseconds used to limit retries. + * @return This MergeInsertParams instance + */ + public MergeInsertParams withRetryTimeoutMs(long timeoutMs) { + this.retryTimeoutMs = timeoutMs; + return this; + } + + /** + * If true, skip auto cleanup during commits. This should be set to true for high frequency writes + * to improve performance. This is also useful if the writer does not have delete permissions and + * the clean up would just try and log a failure anyway. + * + * @param skipAutoCleanup Whether to skip auto cleanup during commits. + * @return This MergeInsertParams instance + */ + public MergeInsertParams withSkipAutoCleanup(boolean skipAutoCleanup) { + this.skipAutoCleanup = skipAutoCleanup; + return this; + } + + public List on() { + return on; + } + + public WhenMatched whenMatched() { + return whenMatched; + } + + public String whenMatchedValue() { + return whenMatched.name(); + } + + public Optional whenMatchedUpdateExpr() { + return whenMatchedUpdateExpr; + } + + public WhenNotMatched whenNotMatched() { + return whenNotMatched; + } + + public String whenNotMatchedValue() { + return whenNotMatched.name(); + } + + public WhenNotMatchedBySource whenNotMatchedBySource() { + return whenNotMatchedBySource; + } + + public String whenNotMatchedBySourceValue() { + return whenNotMatchedBySource.name(); + } + + public Optional whenNotMatchedBySourceDeleteExpr() { + return whenNotMatchedBySourceDeleteExpr; + } + + public Optional whenNotMatchedBySourceDeleteSubstraitExpr() { + return whenNotMatchedBySourceDeleteSubstraitExpr; + } + + public int conflictRetries() { + return conflictRetries; + } + + public long retryTimeoutMs() { + return retryTimeoutMs; + } + + public boolean skipAutoCleanup() { + return skipAutoCleanup; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("on", on) + .add("whenMatched", whenMatched) + .add("whenMatchedUpdateExpr", 
whenMatchedUpdateExpr.orElse(null)) + .add("whenNotMatched", whenNotMatched) + .add("whenNotMatchedBySource", whenNotMatchedBySource) + .add("whenNotMatchedBySourceDeleteExpr", whenNotMatchedBySourceDeleteExpr.orElse(null)) + .add( + "whenNotMatchedBySourceDeleteSubstraitExpr", + whenNotMatchedBySourceDeleteSubstraitExpr + .map(buf -> "ByteBuffer[" + buf.remaining() + " bytes]") + .orElse(null)) + .add("conflictRetries", conflictRetries) + .add("retryTimeoutMs", retryTimeoutMs) + .add("skipAutoCleanup", skipAutoCleanup) + .toString(); + } + + public enum WhenMatched { + /** + * The row is deleted from the target table and a new row is inserted based on the source table. + * This can be used to achieve upsert behavior. + */ + UpdateAll, + + /** The row is kept unchanged. This can be used to achieve find-or-create behavior. */ + DoNothing, + + /** + * The row is updated (similar to UpdateAll) only for rows where the expression evaluates to + * true. + */ + UpdateIf, + } + + public enum WhenNotMatched { + /** + * The new row is inserted into the target table. This is used in both find-or-create and upsert + * operations + */ + InsertAll, + + /** The new row is ignored. 
*/ + DoNothing, + } + + public enum WhenNotMatchedBySource { + /** + * Do not delete rows from the target table This can be used for a find-or-create or an upsert + * operation + */ + Keep, + + /** Delete all rows from target table that don't match a row in the source table */ + Delete, + + /** + * Delete rows from the target table if there is no match AND the expression evaluates to true + * This can be used to replace a region of data with new data + */ + DeleteIf, + } +} diff --git a/java/src/main/java/com/lancedb/lance/merge/MergeInsertResult.java b/java/src/main/java/com/lancedb/lance/merge/MergeInsertResult.java new file mode 100644 index 00000000000..8e539ac9728 --- /dev/null +++ b/java/src/main/java/com/lancedb/lance/merge/MergeInsertResult.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package com.lancedb.lance.merge;
+
+import com.lancedb.lance.Dataset;
+
+/**
+ * Holds the outcome of a merge insert: the merged dataset and the statistics of the operation.
+ *
+ * <p>Native code constructs instances through the two-argument constructor, so its signature must
+ * stay in sync with the JNI layer.
+ */
+public class MergeInsertResult {
+  private final Dataset dataset;
+  private final MergeInsertStats stats;
+
+  /**
+   * @param dataset the dataset produced by the merge
+   * @param stats statistics describing the merge
+   */
+  public MergeInsertResult(Dataset dataset, MergeInsertStats stats) {
+    this.dataset = dataset;
+    this.stats = stats;
+  }
+
+  /** Returns the dataset produced by the merge. */
+  public Dataset dataset() {
+    return dataset;
+  }
+
+  /** Returns the statistics describing the merge. */
+  public MergeInsertStats stats() {
+    return stats;
+  }
+}
diff --git a/java/src/main/java/com/lancedb/lance/merge/MergeInsertStats.java b/java/src/main/java/com/lancedb/lance/merge/MergeInsertStats.java
new file mode 100644
index 00000000000..45a47742d3b
--- /dev/null
+++ b/java/src/main/java/com/lancedb/lance/merge/MergeInsertStats.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.merge;
+
+import com.google.common.base.MoreObjects;
+
+/**
+ * Immutable counters reported by a merge insert.
+ *
+ * <p>Native code constructs instances through the six-argument constructor; the parameter order
+ * and types must stay in sync with the JNI constructor signature {@code (JJJIJJ)V}.
+ */
+public final class MergeInsertStats {
+  private final long numInsertedRows;
+  private final long numUpdatedRows;
+  private final long numDeletedRows;
+  private final int numAttempts;
+  private final long bytesWritten;
+  private final long numFilesWritten;
+
+  /** Creates a statistics record; each argument is stored as given, without validation. */
+  public MergeInsertStats(
+      long numInsertedRows,
+      long numUpdatedRows,
+      long numDeletedRows,
+      int numAttempts,
+      long bytesWritten,
+      long numFilesWritten) {
+    this.numInsertedRows = numInsertedRows;
+    this.numUpdatedRows = numUpdatedRows;
+    this.numDeletedRows = numDeletedRows;
+    this.numAttempts = numAttempts;
+    this.bytesWritten = bytesWritten;
+    this.numFilesWritten = numFilesWritten;
+  }
+
+  /** Returns the number of inserted rows. */
+  public long numInsertedRows() {
+    return numInsertedRows;
+  }
+
+  /** Returns the number of updated rows. */
+  public long numUpdatedRows() {
+    return numUpdatedRows;
+  }
+
+  /** Returns the number of deleted rows. */
+  public long numDeletedRows() {
+    return numDeletedRows;
+  }
+
+  /** Returns the number of attempts the operation took. */
+  public int numAttempts() {
+    return numAttempts;
+  }
+
+  /** Returns the number of bytes written. */
+  public long bytesWritten() {
+    return bytesWritten;
+  }
+
+  /** Returns the number of files written. */
+  public long numFilesWritten() {
+    return numFilesWritten;
+  }
+
+  @Override
+  public String toString() {
+    return MoreObjects.toStringHelper(this)
+        .add("numInsertedRows", numInsertedRows)
+        .add("numUpdatedRows", numUpdatedRows)
+        .add("numDeletedRows", numDeletedRows)
+        .add("numAttempts", numAttempts)
+        .add("bytesWritten", bytesWritten)
+        .add("numFilesWritten", numFilesWritten)
+        .toString();
+  }
+}
diff --git a/java/src/test/java/com/lancedb/lance/MergeInsertTest.java b/java/src/test/java/com/lancedb/lance/MergeInsertTest.java
new file mode 100644
index 00000000000..2e9227260f2
--- /dev/null
+++ b/java/src/test/java/com/lancedb/lance/MergeInsertTest.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance; + +import com.lancedb.lance.merge.MergeInsertParams; +import com.lancedb.lance.merge.MergeInsertResult; + +import org.apache.arrow.c.ArrowArrayStream; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.TreeMap; +import java.util.UUID; + +public class MergeInsertTest { + @TempDir private Path tempDir; + private RootAllocator allocator; + private TestUtils.SimpleTestDataset testDataset; + private Dataset dataset; + + @BeforeEach + public void setup() { + String datasetPath = tempDir.resolve(UUID.randomUUID().toString()).toString(); + allocator = new RootAllocator(Long.MAX_VALUE); + testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + 
dataset = testDataset.write(1, 5); + } + + @AfterEach + public void tearDown() { + dataset.close(); + allocator.close(); + } + + @Test + public void testWhenNotMatchedInsertAll() throws Exception { + // Test insert all unmatched source rows + + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + dataset.mergeInsert( + new MergeInsertParams(Collections.singletonList("id")), sourceStream); + + Assertions.assertEquals( + "{0=Person 0, 1=Person 1, 2=Person 2, 3=Person 3, 4=Person 4, 7=Source 7, 8=Source 8, 9=Source 9}", + readAll(result.dataset()).toString()); + } + } + } + + @Test + public void testWhenNotMatchedDoNothing() throws Exception { + // Test ignore unmatched source rows + + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + dataset.mergeInsert( + new MergeInsertParams(Collections.singletonList("id")) + .withMatchedUpdateAll() + .withNotMatched(MergeInsertParams.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Source 0, 1=Source 1, 2=Source 2, 3=Person 3, 4=Person 4}", + readAll(result.dataset()).toString()); + } + } + } + + @Test + public void testWhenMatchedUpdateIf() throws Exception { + // Test update matched rows if expression is true + + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + dataset.mergeInsert( + new MergeInsertParams(Collections.singletonList("id")) + .withMatchedUpdateIf("target.name = 'Person 0' or target.name = 'Person 1'") + .withNotMatched(MergeInsertParams.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Source 0, 1=Source 1, 2=Person 2, 3=Person 3, 4=Person 4}", + 
readAll(result.dataset()).toString()); + } + } + } + + @Test + public void testWhenNotMatchedBySourceDelete() throws Exception { + // Test delete target rows which are not matched with source. + + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + dataset.mergeInsert( + new MergeInsertParams(Collections.singletonList("id")) + .withNotMatchedBySourceDelete() + .withNotMatched(MergeInsertParams.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Person 0, 1=Person 1, 2=Person 2}", readAll(result.dataset()).toString()); + } + } + } + + @Test + public void testWhenNotMatchedBySourceDeleteIf() throws Exception { + // Test delete target rows which are not matched with source if expression is true + + try (VectorSchemaRoot source = buildSource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + dataset.mergeInsert( + new MergeInsertParams(Collections.singletonList("id")) + .withNotMatchedBySourceDeleteIf("name = 'Person 3'") + .withNotMatched(MergeInsertParams.WhenNotMatched.DoNothing), + sourceStream); + + Assertions.assertEquals( + "{0=Person 0, 1=Person 1, 2=Person 2, 4=Person 4}", + readAll(result.dataset()).toString()); + } + } + } + + private VectorSchemaRoot buildSource(Schema schema, RootAllocator allocator) { + List sourceIds = Arrays.asList(0, 1, 2, 7, 8, 9); + + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + root.allocateNew(); + + IntVector idVector = (IntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + + for (int i = 0; i < sourceIds.size(); i++) { + idVector.setSafe(i, sourceIds.get(i)); + String name = "Source " + sourceIds.get(i); + nameVector.setSafe(i, name.getBytes(StandardCharsets.UTF_8)); + } + + 
root.setRowCount(sourceIds.size()); + + return root; + } + + private ArrowArrayStream convertToStream(VectorSchemaRoot root, RootAllocator allocator) + throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out)) { + writer.start(); + writer.writeBatch(); + writer.end(); + } + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + ArrowStreamReader reader = new ArrowStreamReader(in, allocator); + + ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator); + Data.exportArrayStream(allocator, reader, stream); + + return stream; + } + + private TreeMap readAll(Dataset dataset) throws Exception { + try (ArrowReader reader = dataset.newScan().scanBatches()) { + TreeMap map = new TreeMap<>(); + + while (reader.loadNextBatch()) { + VectorSchemaRoot batch = reader.getVectorSchemaRoot(); + for (int i = 0; i < batch.getRowCount(); i++) { + IntVector idVector = (IntVector) batch.getVector("id"); + VarCharVector nameVector = (VarCharVector) batch.getVector("name"); + map.put(idVector.get(i), new String(nameVector.get(i))); + } + } + + return map; + } + } +}