Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions java/lance-jni/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub mod ffi;
mod file_reader;
mod file_writer;
mod fragment;
mod merge_insert;
mod schema;
mod sql;
pub mod traits;
Expand Down
265 changes: 265 additions & 0 deletions java/lance-jni/src/merge_insert.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use crate::blocking_dataset::{BlockingDataset, NATIVE_DATASET};
use crate::error::Result;
use crate::traits::{FromJString, IntoJava};
use crate::{Error, JNIEnvExt, RT};
use arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
use jni::objects::{JObject, JString, JValueGen};
use jni::sys::jlong;
use jni::JNIEnv;
use lance::dataset::scanner::LanceFilter;
use lance::dataset::{
MergeInsertBuilder, MergeStats, WhenMatched, WhenNotMatched, WhenNotMatchedBySource,
};
use lance_core::datatypes::Schema;
use std::sync::Arc;
use std::time::Duration;

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeMergeInsert<'a>(
mut env: JNIEnv<'a>,
jdataset: JObject, // Dataset object
jparam: JObject, // MergeInsertParams object
batch_address: jlong, // ArrowArrayStream address for source
) -> JObject<'a> {
ok_or_throw!(
env,
inner_merge_insert(&mut env, jdataset, jparam, batch_address)
)
}

#[allow(clippy::too_many_arguments)]
fn inner_merge_insert<'local>(
env: &mut JNIEnv<'local>,
jdataset: JObject,
jparam: JObject,
batch_address: jlong,
) -> Result<JObject<'local>> {
let on = extract_on(env, &jparam)?;
let when_matched = extract_when_matched(env, &jparam)?;
let when_not_matched = extract_when_not_matached(env, &jparam)?;

let when_not_matched_by_source_str = extract_when_not_matched_by_source_str(env, &jparam)?;
let when_not_matched_by_source_delete_expr =
extract_when_not_matched_by_source_delete_expr(env, &jparam)?;

let conflict_retries = extract_conflict_retries(env, &jparam)?;
let retry_timeout_ms = extract_retry_timeout_ms(env, &jparam)?;
let skip_auto_cleanup = extract_skip_auto_cleanup(env, &jparam)?;

let (new_ds, merge_stats) = unsafe {
let dataset = env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET)?;

let when_not_matched_by_source = extract_when_not_matched_by_source(
dataset.inner.schema(),
when_not_matched_by_source_str.as_str(),
when_not_matched_by_source_delete_expr,
)?;

let merge_insert_job = MergeInsertBuilder::try_new(Arc::new(dataset.clone().inner), on)?
.when_matched(when_matched)
.when_not_matched(when_not_matched)
.when_not_matched_by_source(when_not_matched_by_source)
.conflict_retries(conflict_retries)
.retry_timeout(Duration::from_millis(retry_timeout_ms as u64))
.skip_auto_cleanup(skip_auto_cleanup)
.try_build()?;

let stream_ptr = batch_address as *mut FFI_ArrowArrayStream;
let source_stream = ArrowArrayStreamReader::from_raw(stream_ptr)?;

RT.block_on(async move { merge_insert_job.execute_reader(source_stream).await })?
};

MergeResult(
BlockingDataset {
inner: Arc::try_unwrap(new_ds).unwrap(),
},
merge_stats,
)
.into_java(env)
}

fn extract_on<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result<Vec<String>> {
let on: JObject = env
.call_method(jparam, "on", "()Ljava/util/List;", &[])?
.l()?;
env.get_strings(&on)
}

fn extract_when_matched<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result<WhenMatched> {
let when_matched: JString = env
.call_method(jparam, "whenMatchedValue", "()Ljava/lang/String;", &[])?
.l()?
.into();
let when_matched = when_matched.extract(env)?;

let when_matched_update_expr = env
.call_method(
jparam,
"whenMatchedUpdateExpr",
"()Ljava/util/Optional;",
&[],
)?
.l()?;
let when_matched_update_expr = env.get_string_opt(&when_matched_update_expr)?;

match when_matched.as_str() {
"UpdateAll" => Ok(WhenMatched::UpdateAll),
"DoNothing" => Ok(WhenMatched::DoNothing),
"UpdateIf" => match when_matched_update_expr {
Some(expr) => Ok(WhenMatched::UpdateIf(expr)),
None => Err(Error::input_error("No matched updated expr".to_string())),
},
_ => Err(Error::input_error(format!(
"Illegal when_matched: {when_matched}",
))),
}
}

fn extract_when_not_matached<'local>(
env: &mut JNIEnv<'local>,
jparam: &JObject,
) -> Result<WhenNotMatched> {
let when_not_matched: JString = env
.call_method(jparam, "whenNotMatchedValue", "()Ljava/lang/String;", &[])?
.l()?
.into();
let when_not_matched = when_not_matched.extract(env)?;

match when_not_matched.as_str() {
"InsertAll" => Ok(WhenNotMatched::InsertAll),
"DoNothing" => Ok(WhenNotMatched::DoNothing),
_ => Err(Error::input_error(format!(
"Illegal when_not_matched: {when_not_matched}",
))),
}
}

fn extract_when_not_matched_by_source_str<'local>(
env: &mut JNIEnv<'local>,
jparam: &JObject,
) -> Result<String> {
let when_not_matched_by_source: JString = env
.call_method(
jparam,
"whenNotMatchedBySourceValue",
"()Ljava/lang/String;",
&[],
)?
.l()?
.into();
when_not_matched_by_source.extract(env)
}

fn extract_when_not_matched_by_source_delete_expr<'local>(
env: &mut JNIEnv<'local>,
jparam: &JObject,
) -> Result<Option<LanceFilter>> {
let when_not_matched_by_source_delete_expr = env
.call_method(
jparam,
"whenNotMatchedBySourceDeleteExpr",
"()Ljava/util/Optional;",
&[],
)?
.l()?;

if let Some(expr) = env.get_string_opt(&when_not_matched_by_source_delete_expr)? {
return Ok(Some(LanceFilter::Sql(expr)));
}

let when_not_matched_by_source_delete_substrait_expr = env
.call_method(
jparam,
"whenNotMatchedBySourceDeleteSubstraitExpr",
"()Ljava/util/Optional;",
&[],
)?
.l()?;

match env.get_bytes_opt(&when_not_matched_by_source_delete_substrait_expr)? {
Some(expr) => Ok(Some(LanceFilter::Substrait(expr.to_vec()))),
None => Ok(None),
}
}

fn extract_when_not_matched_by_source(
schema: &Schema,
when_not_matched_by_source: &str,
when_not_matched_by_source_delete_expr: Option<LanceFilter>,
) -> Result<WhenNotMatchedBySource> {
match when_not_matched_by_source {
"Keep" => Ok(WhenNotMatchedBySource::Keep),
"Delete" => Ok(WhenNotMatchedBySource::Delete),
"DeleteIf" => match when_not_matched_by_source_delete_expr {
Some(expr) => Ok(WhenNotMatchedBySource::DeleteIf(
expr.to_datafusion(schema, schema)?,
)),
None => Err(Error::input_error(format!(
"No delete expr when not matched by source is: {when_not_matched_by_source}",
))),
},
_ => Err(Error::input_error(format!(
"Illegal when_not_matched_by_source: {when_not_matched_by_source}",
))),
}
}

fn extract_conflict_retries<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result<u32> {
let retries = env
.call_method(jparam, "conflictRetries", "()I", &[])?
.i()? as u32;
Ok(retries)
}

fn extract_retry_timeout_ms<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result<u64> {
let timeout_ms = env.call_method(jparam, "retryTimeoutMs", "()J", &[])?.j()? as u64;
Ok(timeout_ms)
}

fn extract_skip_auto_cleanup<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Result<bool> {
let skip_auto_cleanup = env
.call_method(jparam, "skipAutoCleanup", "()Z", &[])?
.z()?;
Ok(skip_auto_cleanup)
}

const MERGE_STATS_CLASS: &str = "com/lancedb/lance/merge/MergeInsertStats";
const MERGE_STATS_CONSTRUCTOR_SIG: &str = "(JJJIJJ)V";
const MERGE_RESULT_CLASS: &str = "com/lancedb/lance/merge/MergeInsertResult";
const MERGE_RESULT_CONSTRUCTOR_SIG: &str =
"(Lcom/lancedb/lance/Dataset;Lcom/lancedb/lance/merge/MergeInsertStats;)V";

impl IntoJava for MergeStats {
fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result<JObject<'a>> {
Ok(env.new_object(
MERGE_STATS_CLASS,
MERGE_STATS_CONSTRUCTOR_SIG,
&[
JValueGen::Long(self.num_inserted_rows as i64),
JValueGen::Long(self.num_updated_rows as i64),
JValueGen::Long(self.num_deleted_rows as i64),
JValueGen::Int(self.num_attempts as i32),
JValueGen::Long(self.bytes_written as i64),
JValueGen::Long(self.num_files_written as i64),
],
)?)
}
}

struct MergeResult(BlockingDataset, MergeStats);

impl IntoJava for MergeResult {
fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result<JObject<'a>> {
let jdataset = self.0.into_java(env)?;
let jstats = self.1.into_java(env)?;
Ok(env.new_object(
MERGE_RESULT_CLASS,
MERGE_RESULT_CONSTRUCTOR_SIG,
&[JValueGen::Object(&jdataset), JValueGen::Object(&jstats)],
)?)
}
}
33 changes: 33 additions & 0 deletions java/src/main/java/com/lancedb/lance/Dataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import com.lancedb.lance.ipc.DataStatistics;
import com.lancedb.lance.ipc.LanceScanner;
import com.lancedb.lance.ipc.ScanOptions;
import com.lancedb.lance.merge.MergeInsertParams;
import com.lancedb.lance.merge.MergeInsertResult;
import com.lancedb.lance.schema.ColumnAlteration;
import com.lancedb.lance.schema.LanceSchema;
import com.lancedb.lance.schema.SqlExpressions;
Expand Down Expand Up @@ -961,6 +963,37 @@ public SqlQuery sql(String sql) {
return new SqlQuery(this, sql);
}

/**
* Merge source data with the existing target data.
*
* <p>This will take in the source, merge it with the existing target data, and insert new rows,
* update existing rows, and delete existing rows.
*
* <p>It is important that after merge insert, the current dataset is changed and should be
* closed. The merged new dataset is contained in the MergeInsertResult.
*
* @param mergeInsert merge insert options
* @param source ArrowArrayStream source data
* @return MergeInsertResult containing the new merged Dataset.
*/
public MergeInsertResult mergeInsert(MergeInsertParams mergeInsert, ArrowArrayStream source) {
try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) {
MergeInsertResult result = nativeMergeInsert(mergeInsert, source.memoryAddress());

Dataset newDataset = result.dataset();
if (selfManagedAllocator) {
newDataset.allocator = new RootAllocator(Long.MAX_VALUE);
} else {
newDataset.allocator = allocator;
}

return result;
}
}

private native MergeInsertResult nativeMergeInsert(
MergeInsertParams mergeInsert, long arrowStreamMemoryAddress);

private native void nativeCreateTag(String tag, long version);

private native void nativeDeleteTag(String tag);
Expand Down
Loading