Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 22 additions & 74 deletions java/lance-jni/src/blocking_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1039,60 +1039,18 @@ fn inner_open_native<'local>(
let storage_options = to_rust_map(env, &jmap)?;

// Extract storage options provider first (before get_bytes_opt which borrows env)
let storage_options_provider = if !storage_options_provider_obj.is_null() {
// Check if it's an Optional.empty()
let is_present = env
.call_method(&storage_options_provider_obj, "isPresent", "()Z", &[])?
.z()?;
if is_present {
// Get the value from Optional
let provider_obj = env
.call_method(
&storage_options_provider_obj,
"get",
"()Ljava/lang/Object;",
&[],
)?
.l()?;
Some(JavaStorageOptionsProvider::new(env, provider_obj)?)
} else {
None
}
} else {
None
};
let storage_options_provider = env
.get_optional(&storage_options_provider_obj, |env, provider_obj| {
JavaStorageOptionsProvider::new(env, provider_obj)
})?;

let storage_options_provider_arc =
storage_options_provider.map(|v| Arc::new(v) as Arc<dyn StorageOptionsProvider>);

// Extract s3_credentials_refresh_offset_seconds
let s3_credentials_refresh_offset_seconds =
if !s3_credentials_refresh_offset_seconds_obj.is_null() {
let is_present = env
.call_method(
&s3_credentials_refresh_offset_seconds_obj,
"isPresent",
"()Z",
&[],
)?
.z()?;
if is_present {
let value = env
.call_method(
&s3_credentials_refresh_offset_seconds_obj,
"get",
"()Ljava/lang/Object;",
&[],
)?
.l()?;
let long_value = env.call_method(&value, "longValue", "()J", &[])?.j()?;
Some(long_value as u64)
} else {
None
}
} else {
None
};
let s3_credentials_refresh_offset_seconds = env
.get_long_opt(&s3_credentials_refresh_offset_seconds_obj)?
.map(|v| v as u64);

let serialized_manifest = env.get_bytes_opt(&serialized_manifest)?;
let dataset = BlockingDataset::open(
Expand Down Expand Up @@ -1404,7 +1362,7 @@ fn inner_shallow_clone<'local>(
) -> Result<JObject<'local>> {
let target_path_str = target_path.extract(env)?;
let storage_options = env.get_optional(&storage_options, |env, map_obj| {
let jmap = JMap::from_env(env, map_obj)?;
let jmap = JMap::from_env(env, &map_obj)?;
to_rust_map(env, &jmap)
})?;

Expand Down Expand Up @@ -1830,18 +1788,13 @@ fn inner_add_columns_by_sql_expressions(

let rust_transform = NewColumnTransform::SqlExpressions(expressions);

let batch_size = if env.call_method(&batch_size, "isPresent", "()Z", &[])?.z()? {
let batch_size_value = env.get_long_opt(&batch_size)?;
match batch_size_value {
Some(value) => Some(
value
.try_into()
.map_err(|_| Error::input_error("Batch size conversion error".to_string()))?,
),
None => None,
}
} else {
None
let batch_size = match env.get_long_opt(&batch_size)? {
Some(value) => Some(
value
.try_into()
.map_err(|_| Error::input_error("Batch size conversion error".to_string()))?,
),
None => None,
};

let mut dataset_guard =
Expand Down Expand Up @@ -1880,18 +1833,13 @@ fn inner_add_columns_by_reader(

let transform = NewColumnTransform::Reader(Box::new(reader));

let batch_size = if env.call_method(&batch_size, "isPresent", "()Z", &[])?.z()? {
let batch_size_value = env.get_long_opt(&batch_size)?;
match batch_size_value {
Some(value) => Some(
value
.try_into()
.map_err(|_| Error::input_error("Batch size conversion error".to_string()))?,
),
None => None,
}
} else {
None
let batch_size = match env.get_long_opt(&batch_size)? {
Some(value) => Some(
value
.try_into()
.map_err(|_| Error::input_error("Batch size conversion error".to_string()))?,
),
None => None,
};

let mut dataset_guard =
Expand Down
24 changes: 7 additions & 17 deletions java/lance-jni/src/blocking_scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,7 @@ fn inner_create_scanner<'local>(
scanner.with_row_address();
}

let query_is_present = env.call_method(&query_obj, "isPresent", "()Z", &[])?.z()?;

if query_is_present {
let java_obj = env
.call_method(&query_obj, "get", "()Ljava/lang/Object;", &[])?
.l()?;

env.get_optional(&query_obj, |env, java_obj| {
// Set column and key for nearest search
let column = env.get_string_from_method(&java_obj, "getColumn")?;
let key_array = env.get_vec_f32_from_method(&java_obj, "getKey")?;
Expand Down Expand Up @@ -207,17 +201,12 @@ fn inner_create_scanner<'local>(

let use_index = env.get_boolean_from_method(&java_obj, "isUseIndex")?;
scanner.use_index(use_index);
}
scanner.batch_readahead(batch_readahead as usize);
Ok(())
})?;

let column_orders_is_present = env
.call_method(&column_orderings, "isPresent", "()Z", &[])?
.z()?;
if column_orders_is_present {
let java_obj = env
.call_method(&column_orderings, "get", "()Ljava/lang/Object;", &[])?
.l()?;
scanner.batch_readahead(batch_readahead as usize);

env.get_optional(&column_orderings, |env, java_obj| {
let list = env.get_list(&java_obj)?;
let mut iter = list.iter(env)?;
let mut results = Vec::with_capacity(list.size(env)? as usize);
Expand All @@ -233,7 +222,8 @@ fn inner_create_scanner<'local>(
results.push(col_order)
}
scanner.order_by(Some(results))?;
}
Ok(())
})?;

let scanner = BlockingScanner::create(scanner);
scanner.into_java(env)
Expand Down
113 changes: 27 additions & 86 deletions java/lance-jni/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ pub trait JNIEnvExt {

fn get_optional<T, F>(&mut self, obj: &JObject, f: F) -> Result<Option<T>>
where
F: FnOnce(&mut JNIEnv, &JObject) -> Result<T>;
F: FnOnce(&mut JNIEnv, JObject) -> Result<T>;
}

impl JNIEnvExt for JNIEnv<'_> {
Expand Down Expand Up @@ -197,85 +197,63 @@ impl JNIEnvExt for JNIEnv<'_> {
}

fn get_string_opt(&mut self, obj: &JObject) -> Result<Option<String>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_string_obj = java_obj_gen.l()?;
self.get_optional(obj, |env, java_string_obj| {
let jstr = JString::from(java_string_obj);
let val = env.get_string(&jstr)?;
Ok(val.to_str()?.to_string())
})
}

fn get_strings_opt(&mut self, obj: &JObject) -> Result<Option<Vec<String>>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_list_obj = java_obj_gen.l()?;
env.get_strings(&java_list_obj)
})
self.get_optional(obj, |env, java_list_obj| env.get_strings(&java_list_obj))
}

fn get_int_opt(&mut self, obj: &JObject) -> Result<Option<i32>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_int_obj = java_obj_gen.l()?;
self.get_optional(obj, |env, java_int_obj| {
let int_obj = env.call_method(java_int_obj, "intValue", "()I", &[])?;
let int_value = int_obj.i()?;
Ok(int_value)
})
}

fn get_ints_opt(&mut self, obj: &JObject) -> Result<Option<Vec<i32>>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_list_obj = java_obj_gen.l()?;
env.get_integers(&java_list_obj)
})
self.get_optional(obj, |env, java_list_obj| env.get_integers(&java_list_obj))
}

fn get_long_opt(&mut self, obj: &JObject) -> Result<Option<i64>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_long_obj = java_obj_gen.l()?;
self.get_optional(obj, |env, java_long_obj| {
let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[])?;
let long_value = long_obj.j()?;
Ok(long_value)
})
}

fn get_boolean_opt(&mut self, obj: &JObject) -> Result<Option<bool>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_boolean_obj = java_obj_gen.l()?;
self.get_optional(obj, |env, java_boolean_obj| {
let boolean_obj = env.call_method(java_boolean_obj, "booleanValue", "()Z", &[])?;
let boolean_value = boolean_obj.z()?;
Ok(boolean_value)
})
}

fn get_f32_opt(&mut self, obj: &JObject) -> Result<Option<f32>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_float_obj = java_obj_gen.l()?;
self.get_optional(obj, |env, java_float_obj| {
let float_obj = env.call_method(java_float_obj, "floatValue", "()F", &[])?;
let float_value = float_obj.f()?;
Ok(float_value)
})
}

fn get_u64_opt(&mut self, obj: &JObject) -> Result<Option<u64>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_long_obj = java_obj_gen.l()?;
self.get_optional(obj, |env, java_long_obj| {
let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[])?;
let long_value = long_obj.j()?;
Ok(long_value as u64)
})
}

fn get_bytes_opt(&mut self, obj: &JObject) -> Result<Option<&[u8]>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_byte_buffer_obj = java_obj_gen.l()?;
self.get_optional(obj, |env, java_byte_buffer_obj| {
let j_byte_buffer = JByteBuffer::from(java_byte_buffer_obj);
let raw_data = env.get_direct_buffer_address(&j_byte_buffer)?;
let capacity = env.get_direct_buffer_capacity(&j_byte_buffer)?;
Expand All @@ -288,10 +266,7 @@ impl JNIEnvExt for JNIEnv<'_> {
where
F: Fn(&mut JNIEnv, &JObject) -> Result<T>,
{
self.get_optional(obj, |env, opt_obj| {
let list_obj = env
.call_method(opt_obj, "get", "()Ljava/lang/Object;", &[])?
.l()?;
self.get_optional(obj, |env, list_obj| {
let list = env.get_list(&list_obj)?;
let mut iter = list.iter(env)?;
let mut items: Vec<T> = Vec::with_capacity(list.size(env)? as usize);
Expand Down Expand Up @@ -369,24 +344,12 @@ impl JNIEnvExt for JNIEnv<'_> {
T: TryFrom<i32>,
<T as TryFrom<i32>>::Error: std::fmt::Debug,
{
let java_object = self
.call_method(obj, method_name, "()Ljava/util/Optional;", &[])?
.l()?;
let rust_obj = if self
.call_method(&java_object, "isPresent", "()Z", &[])?
.z()?
{
let inner_jobj = self
.call_method(&java_object, "get", "()Ljava/lang/Object;", &[])?
.l()?;
let inner_value = self.call_method(&inner_jobj, "intValue", "()I", &[])?.i()?;
Some(T::try_from(inner_value).map_err(|e| {
self.get_optional_from_method(obj, method_name, |env, inner_jobj| {
let inner_value = env.call_method(&inner_jobj, "intValue", "()I", &[])?.i()?;
T::try_from(inner_value).map_err(|e| {
Error::io_error(format!("Failed to convert from i32 to rust type: {:?}", e))
})?)
} else {
None
};
Ok(rust_obj)
})
})
}

fn get_optional_i64_from_method(
Expand Down Expand Up @@ -414,26 +377,12 @@ impl JNIEnvExt for JNIEnv<'_> {
T: TryFrom<i64>,
<T as TryFrom<i64>>::Error: std::fmt::Debug,
{
let java_object = self
.call_method(obj, method_name, "()Ljava/util/Optional;", &[])?
.l()?;
let rust_obj = if self
.call_method(&java_object, "isPresent", "()Z", &[])?
.z()?
{
let inner_jobj = self
.call_method(&java_object, "get", "()Ljava/lang/Object;", &[])?
.l()?;
let inner_value = self
.call_method(&inner_jobj, "longValue", "()J", &[])?
.j()?;
Some(T::try_from(inner_value).map_err(|e| {
self.get_optional_from_method(obj, method_name, |env, inner_jobj| {
let inner_value = env.call_method(&inner_jobj, "longValue", "()J", &[])?.j()?;
T::try_from(inner_value).map_err(|e| {
Error::io_error(format!("Failed to convert from i32 to rust type: {:?}", e))
})?)
} else {
None
};
Ok(rust_obj)
})
})
}

fn get_optional_string_from_method(
Expand All @@ -460,30 +409,22 @@ impl JNIEnvExt for JNIEnv<'_> {
let optional_obj = self
.call_method(obj, method_name, "()Ljava/util/Optional;", &[])?
.l()?;

if self
.call_method(&optional_obj, "isPresent", "()Z", &[])?
.z()?
{
let inner_obj = self
.call_method(&optional_obj, "get", "()Ljava/lang/Object;", &[])?
.l()?;
f(self, inner_obj).map(Some)
} else {
Ok(None)
}
self.get_optional(&optional_obj, f)
}

fn get_optional<T, F>(&mut self, obj: &JObject, f: F) -> Result<Option<T>>
where
F: FnOnce(&mut JNIEnv, &JObject) -> Result<T>,
F: FnOnce(&mut JNIEnv, JObject) -> Result<T>,
{
if obj.is_null() {
return Ok(None);
}
let is_present = self.call_method(obj, "isPresent", "()Z", &[])?;
if is_present.z()? {
f(self, obj).map(Some)
let inner_obj = self
.call_method(obj, "get", "()Ljava/lang/Object;", &[])?
.l()?;
f(self, inner_obj).map(Some)
} else {
// TODO(lu): put get java object into here cuz can only get java Object
Ok(None)
Expand Down
Loading