From 978c80180892e305977bfa36d88bf31ec498f7c1 Mon Sep 17 00:00:00 2001 From: Ddupg Date: Mon, 15 Dec 2025 21:00:35 +0800 Subject: [PATCH] feat(java): simplify the use of optional in jni Change-Id: I94ff021eefc1260ce7f4afb2cc5063c829b0c59f --- java/lance-jni/src/blocking_dataset.rs | 96 ++----- java/lance-jni/src/blocking_scanner.rs | 24 +- java/lance-jni/src/ffi.rs | 113 ++------ java/lance-jni/src/fragment.rs | 14 +- java/lance-jni/src/utils.rs | 344 +++++++++++-------------- 5 files changed, 208 insertions(+), 383 deletions(-) diff --git a/java/lance-jni/src/blocking_dataset.rs b/java/lance-jni/src/blocking_dataset.rs index 805c6d71360..55b108dd9cc 100644 --- a/java/lance-jni/src/blocking_dataset.rs +++ b/java/lance-jni/src/blocking_dataset.rs @@ -1039,60 +1039,18 @@ fn inner_open_native<'local>( let storage_options = to_rust_map(env, &jmap)?; // Extract storage options provider first (before get_bytes_opt which borrows env) - let storage_options_provider = if !storage_options_provider_obj.is_null() { - // Check if it's an Optional.empty() - let is_present = env - .call_method(&storage_options_provider_obj, "isPresent", "()Z", &[])? - .z()?; - if is_present { - // Get the value from Optional - let provider_obj = env - .call_method( - &storage_options_provider_obj, - "get", - "()Ljava/lang/Object;", - &[], - )? - .l()?; - Some(JavaStorageOptionsProvider::new(env, provider_obj)?) - } else { - None - } - } else { - None - }; + let storage_options_provider = env + .get_optional(&storage_options_provider_obj, |env, provider_obj| { + JavaStorageOptionsProvider::new(env, provider_obj) + })?; let storage_options_provider_arc = storage_options_provider.map(|v| Arc::new(v) as Arc); // Extract s3_credentials_refresh_offset_seconds - let s3_credentials_refresh_offset_seconds = - if !s3_credentials_refresh_offset_seconds_obj.is_null() { - let is_present = env - .call_method( - &s3_credentials_refresh_offset_seconds_obj, - "isPresent", - "()Z", - &[], - )? - .z()?; - if is_present { - let value = env - .call_method( - &s3_credentials_refresh_offset_seconds_obj, - "get", - "()Ljava/lang/Object;", - &[], - )? - .l()?; - let long_value = env.call_method(&value, "longValue", "()J", &[])?.j()?; - Some(long_value as u64) - } else { - None - } - } else { - None - }; + let s3_credentials_refresh_offset_seconds = env + .get_long_opt(&s3_credentials_refresh_offset_seconds_obj)? + .map(|v| v as u64); let serialized_manifest = env.get_bytes_opt(&serialized_manifest)?; let dataset = BlockingDataset::open( @@ -1404,7 +1362,7 @@ fn inner_shallow_clone<'local>( ) -> Result> { let target_path_str = target_path.extract(env)?; let storage_options = env.get_optional(&storage_options, |env, map_obj| { - let jmap = JMap::from_env(env, map_obj)?; + let jmap = JMap::from_env(env, &map_obj)?; to_rust_map(env, &jmap) })?; @@ -1830,18 +1788,13 @@ fn inner_add_columns_by_sql_expressions( let rust_transform = NewColumnTransform::SqlExpressions(expressions); - let batch_size = if env.call_method(&batch_size, "isPresent", "()Z", &[])?.z()? { - let batch_size_value = env.get_long_opt(&batch_size)?; - match batch_size_value { - Some(value) => Some( - value - .try_into() - .map_err(|_| Error::input_error("Batch size conversion error".to_string()))?, - ), - None => None, - } - } else { - None + let batch_size = match env.get_long_opt(&batch_size)? { + Some(value) => Some( + value + .try_into() + .map_err(|_| Error::input_error("Batch size conversion error".to_string()))?, + ), + None => None, }; let mut dataset_guard = @@ -1880,18 +1833,13 @@ fn inner_add_columns_by_reader( let transform = NewColumnTransform::Reader(Box::new(reader)); - let batch_size = if env.call_method(&batch_size, "isPresent", "()Z", &[])?.z()? { - let batch_size_value = env.get_long_opt(&batch_size)?; - match batch_size_value { - Some(value) => Some( - value - .try_into() - .map_err(|_| Error::input_error("Batch size conversion error".to_string()))?, - ), - None => None, - } - } else { - None + let batch_size = match env.get_long_opt(&batch_size)? { + Some(value) => Some( + value + .try_into() + .map_err(|_| Error::input_error("Batch size conversion error".to_string()))?, + ), + None => None, }; let mut dataset_guard = diff --git a/java/lance-jni/src/blocking_scanner.rs b/java/lance-jni/src/blocking_scanner.rs index 4790219b09a..a8f3c807ed7 100644 --- a/java/lance-jni/src/blocking_scanner.rs +++ b/java/lance-jni/src/blocking_scanner.rs @@ -165,13 +165,7 @@ fn inner_create_scanner<'local>( scanner.with_row_address(); } - let query_is_present = env.call_method(&query_obj, "isPresent", "()Z", &[])?.z()?; - - if query_is_present { - let java_obj = env - .call_method(&query_obj, "get", "()Ljava/lang/Object;", &[])? - .l()?; - + env.get_optional(&query_obj, |env, java_obj| { // Set column and key for nearest search let column = env.get_string_from_method(&java_obj, "getColumn")?; let key_array = env.get_vec_f32_from_method(&java_obj, "getKey")?; @@ -207,17 +201,12 @@ fn inner_create_scanner<'local>( let use_index = env.get_boolean_from_method(&java_obj, "isUseIndex")?; scanner.use_index(use_index); - } - scanner.batch_readahead(batch_readahead as usize); + Ok(()) + })?; - let column_orders_is_present = env - .call_method(&column_orderings, "isPresent", "()Z", &[])? - .z()?; - if column_orders_is_present { - let java_obj = env - .call_method(&column_orderings, "get", "()Ljava/lang/Object;", &[])? - .l()?; + scanner.batch_readahead(batch_readahead as usize); + env.get_optional(&column_orderings, |env, java_obj| { let list = env.get_list(&java_obj)?; let mut iter = list.iter(env)?; let mut results = Vec::with_capacity(list.size(env)? as usize); @@ -233,7 +222,8 @@ fn inner_create_scanner<'local>( results.push(col_order) } scanner.order_by(Some(results))?; - } + Ok(()) + })?; let scanner = BlockingScanner::create(scanner); scanner.into_java(env) diff --git a/java/lance-jni/src/ffi.rs b/java/lance-jni/src/ffi.rs index 1c76e899c52..5889e562c6b 100644 --- a/java/lance-jni/src/ffi.rs +++ b/java/lance-jni/src/ffi.rs @@ -145,7 +145,7 @@ pub trait JNIEnvExt { fn get_optional(&mut self, obj: &JObject, f: F) -> Result> where - F: FnOnce(&mut JNIEnv, &JObject) -> Result; + F: FnOnce(&mut JNIEnv, JObject) -> Result; } impl JNIEnvExt for JNIEnv<'_> { @@ -197,9 +197,7 @@ impl JNIEnvExt for JNIEnv<'_> { } fn get_string_opt(&mut self, obj: &JObject) -> Result> { - self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_string_obj = java_obj_gen.l()?; + self.get_optional(obj, |env, java_string_obj| { let jstr = JString::from(java_string_obj); let val = env.get_string(&jstr)?; Ok(val.to_str()?.to_string()) @@ -207,17 +205,11 @@ impl JNIEnvExt for JNIEnv<'_> { } fn get_strings_opt(&mut self, obj: &JObject) -> Result>> { - self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_list_obj = java_obj_gen.l()?; - env.get_strings(&java_list_obj) - }) + self.get_optional(obj, |env, java_list_obj| env.get_strings(&java_list_obj)) } fn get_int_opt(&mut self, obj: &JObject) -> Result> { - self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_int_obj = java_obj_gen.l()?; + self.get_optional(obj, |env, java_int_obj| { let int_obj = env.call_method(java_int_obj, "intValue", "()I", &[])?; let int_value = int_obj.i()?; Ok(int_value) @@ -225,17 +217,11 @@ impl JNIEnvExt for JNIEnv<'_> { } fn get_ints_opt(&mut self, obj: &JObject) -> Result>> { - self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_list_obj = java_obj_gen.l()?; - env.get_integers(&java_list_obj) - }) + self.get_optional(obj, |env, java_list_obj| env.get_integers(&java_list_obj)) } fn get_long_opt(&mut self, obj: &JObject) -> Result> { - self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_long_obj = java_obj_gen.l()?; + self.get_optional(obj, |env, java_long_obj| { let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[])?; let long_value = long_obj.j()?; Ok(long_value) @@ -243,9 +229,7 @@ impl JNIEnvExt for JNIEnv<'_> { } fn get_boolean_opt(&mut self, obj: &JObject) -> Result> { - self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_boolean_obj = java_obj_gen.l()?; + self.get_optional(obj, |env, java_boolean_obj| { let boolean_obj = env.call_method(java_boolean_obj, "booleanValue", "()Z", &[])?; let boolean_value = boolean_obj.z()?; Ok(boolean_value) @@ -253,9 +237,7 @@ impl JNIEnvExt for JNIEnv<'_> { } fn get_f32_opt(&mut self, obj: &JObject) -> Result> { - self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_float_obj = java_obj_gen.l()?; + self.get_optional(obj, |env, java_float_obj| { let float_obj = env.call_method(java_float_obj, "floatValue", "()F", &[])?; let float_value = float_obj.f()?; Ok(float_value) @@ -263,9 +245,7 @@ impl JNIEnvExt for JNIEnv<'_> { } fn get_u64_opt(&mut self, obj: &JObject) -> Result> { - self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_long_obj = java_obj_gen.l()?; + self.get_optional(obj, |env, java_long_obj| { let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[])?; let long_value = long_obj.j()?; Ok(long_value as u64) @@ -273,9 +253,7 @@ impl JNIEnvExt for JNIEnv<'_> { } fn get_bytes_opt(&mut self, obj: &JObject) -> Result> { - self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_byte_buffer_obj = java_obj_gen.l()?; + self.get_optional(obj, |env, java_byte_buffer_obj| { let j_byte_buffer = JByteBuffer::from(java_byte_buffer_obj); let raw_data = env.get_direct_buffer_address(&j_byte_buffer)?; let capacity = env.get_direct_buffer_capacity(&j_byte_buffer)?; @@ -288,10 +266,7 @@ impl JNIEnvExt for JNIEnv<'_> { where F: Fn(&mut JNIEnv, &JObject) -> Result, { - self.get_optional(obj, |env, opt_obj| { - let list_obj = env - .call_method(opt_obj, "get", "()Ljava/lang/Object;", &[])? - .l()?; + self.get_optional(obj, |env, list_obj| { let list = env.get_list(&list_obj)?; let mut iter = list.iter(env)?; let mut items: Vec = Vec::with_capacity(list.size(env)? as usize); @@ -369,24 +344,12 @@ impl JNIEnvExt for JNIEnv<'_> { T: TryFrom, >::Error: std::fmt::Debug, { - let java_object = self - .call_method(obj, method_name, "()Ljava/util/Optional;", &[])? - .l()?; - let rust_obj = if self - .call_method(&java_object, "isPresent", "()Z", &[])? - .z()? - { - let inner_jobj = self - .call_method(&java_object, "get", "()Ljava/lang/Object;", &[])? - .l()?; - let inner_value = self.call_method(&inner_jobj, "intValue", "()I", &[])?.i()?; - Some(T::try_from(inner_value).map_err(|e| { + self.get_optional_from_method(obj, method_name, |env, inner_jobj| { + let inner_value = env.call_method(&inner_jobj, "intValue", "()I", &[])?.i()?; + T::try_from(inner_value).map_err(|e| { Error::io_error(format!("Failed to convert from i32 to rust type: {:?}", e)) - })?) - } else { - None - }; - Ok(rust_obj) + }) + }) } fn get_optional_i64_from_method( @@ -414,26 +377,12 @@ impl JNIEnvExt for JNIEnv<'_> { T: TryFrom, >::Error: std::fmt::Debug, { - let java_object = self - .call_method(obj, method_name, "()Ljava/util/Optional;", &[])? - .l()?; - let rust_obj = if self - .call_method(&java_object, "isPresent", "()Z", &[])? - .z()? - { - let inner_jobj = self - .call_method(&java_object, "get", "()Ljava/lang/Object;", &[])? - .l()?; - let inner_value = self - .call_method(&inner_jobj, "longValue", "()J", &[])? - .j()?; - Some(T::try_from(inner_value).map_err(|e| { + self.get_optional_from_method(obj, method_name, |env, inner_jobj| { + let inner_value = env.call_method(&inner_jobj, "longValue", "()J", &[])?.j()?; + T::try_from(inner_value).map_err(|e| { Error::io_error(format!("Failed to convert from i32 to rust type: {:?}", e)) - })?) - } else { - None - }; - Ok(rust_obj) + }) + }) } fn get_optional_string_from_method( @@ -460,30 +409,22 @@ impl JNIEnvExt for JNIEnv<'_> { let optional_obj = self .call_method(obj, method_name, "()Ljava/util/Optional;", &[])? .l()?; - - if self - .call_method(&optional_obj, "isPresent", "()Z", &[])? - .z()? - { - let inner_obj = self - .call_method(&optional_obj, "get", "()Ljava/lang/Object;", &[])? - .l()?; - f(self, inner_obj).map(Some) - } else { - Ok(None) - } + self.get_optional(&optional_obj, f) } fn get_optional(&mut self, obj: &JObject, f: F) -> Result> where - F: FnOnce(&mut JNIEnv, &JObject) -> Result, + F: FnOnce(&mut JNIEnv, JObject) -> Result, { if obj.is_null() { return Ok(None); } let is_present = self.call_method(obj, "isPresent", "()Z", &[])?; if is_present.z()? { - f(self, obj).map(Some) + let inner_obj = self + .call_method(obj, "get", "()Ljava/lang/Object;", &[])? + .l()?; + f(self, inner_obj).map(Some) } else { // TODO(lu): put get java object into here cuz can only get java Object Ok(None) diff --git a/java/lance-jni/src/fragment.rs b/java/lance-jni/src/fragment.rs index 1f6ec0f85fd..c3caec67d82 100644 --- a/java/lance-jni/src/fragment.rs +++ b/java/lance-jni/src/fragment.rs @@ -745,19 +745,7 @@ impl FromJObjectWithEnv for JObject<'_> { } fn get_base_id(env: &mut JNIEnv, obj: &JObject) -> Result> { - let base_id = env - .call_method(obj, "getBaseId", "()Ljava/util/Optional;", &[])? - .l()?; - - if env.call_method(&base_id, "isPresent", "()Z", &[])?.z()? { - let inner_value = env - .call_method(&base_id, "get", "()Ljava/lang/Object;", &[])? - .l()?; - let int_value = env.call_method(&inner_value, "intValue", "()I", &[])?.i()?; - Ok(Some(int_value as u32)) - } else { - Ok(None) - } + env.get_optional_u32_from_method(obj, "getBaseId") } fn convert_to_java_integer<'local>( diff --git a/java/lance-jni/src/utils.rs b/java/lance-jni/src/utils.rs index 336d0a4ac48..5cb55c200e1 100644 --- a/java/lance-jni/src/utils.rs +++ b/java/lance-jni/src/utils.rs @@ -79,12 +79,10 @@ pub fn extract_write_params( extract_storage_options(env, storage_options_obj)?; // Extract storage options provider if present - let storage_options_provider = env.get_optional(storage_options_provider_obj, |env, obj| { - let provider_obj = env - .call_method(obj, "get", "()Ljava/lang/Object;", &[])? - .l()?; - JavaStorageOptionsProvider::new(env, provider_obj) - })?; + let storage_options_provider = env + .get_optional(storage_options_provider_obj, |env, provider_obj| { + JavaStorageOptionsProvider::new(env, provider_obj) + })?; let storage_options_provider_arc: Option> = storage_options_provider.map(|v| Arc::new(v) as Arc); @@ -160,10 +158,7 @@ pub fn build_compaction_options( // Convert from Java Optional to Rust Option pub fn get_query(env: &mut JNIEnv, query_obj: JObject) -> Result> { - let query = env.get_optional(&query_obj, |env, obj| { - let java_obj_gen = env.call_method(obj, "get", "()Ljava/lang/Object;", &[])?; - let java_obj = java_obj_gen.l()?; - + let query = env.get_optional(&query_obj, |env, java_obj| { let column = env.get_string_from_method(&java_obj, "getColumn")?; let key_array = env.get_vec_f32_from_method(&java_obj, "getKey")?; let key = Arc::new(Float32Array::from(key_array)); @@ -208,151 +203,134 @@ pub fn get_vector_index_params( env: &mut JNIEnv, index_params_obj: JObject, ) -> Result> { - let vector_index_params_option_object = env - .call_method( - index_params_obj, - "getVectorIndexParams", - "()Ljava/util/Optional;", - &[], - )? - .l()?; - - let vector_index_params_option = if env - .call_method(&vector_index_params_option_object, "isPresent", "()Z", &[])? - .z()? - { - let vector_index_params_obj = env - .call_method( - &vector_index_params_option_object, - "get", - "()Ljava/lang/Object;", - &[], - )? - .l()?; - - // Get distance type from VectorIndexParams - let distance_type_obj: JString = env - .call_method( + let vector_index_params_option = env.get_optional_from_method( + &index_params_obj, + "getVectorIndexParams", + |env, vector_index_params_obj| { + // Get distance type from VectorIndexParams + let distance_type_obj: JString = env + .call_method( + &vector_index_params_obj, + "getDistanceTypeString", + "()Ljava/lang/String;", + &[], + )? + .l()? + .into(); + let distance_type_str: String = env.get_string(&distance_type_obj)?.into(); + let distance_type = DistanceType::try_from(distance_type_str.as_str())?; + + let ivf_params_obj = env + .call_method( + &vector_index_params_obj, + "getIvfParams", + "()Lorg/lance/index/vector/IvfBuildParams;", + &[], + )? + .l()?; + + let mut stages = Vec::new(); + + // Parse IvfBuildParams + let num_partitions = + env.get_int_as_usize_from_method(&ivf_params_obj, "getNumPartitions")?; + let max_iters = env.get_int_as_usize_from_method(&ivf_params_obj, "getMaxIters")?; + let sample_rate = env.get_int_as_usize_from_method(&ivf_params_obj, "getSampleRate")?; + let shuffle_partition_batches = + env.get_int_as_usize_from_method(&ivf_params_obj, "getShufflePartitionBatches")?; + let shuffle_partition_concurrency = env + .get_int_as_usize_from_method(&ivf_params_obj, "getShufflePartitionConcurrency")?; + + let ivf_params = IvfBuildParams { + num_partitions: Some(num_partitions), + max_iters, + sample_rate, + shuffle_partition_batches, + shuffle_partition_concurrency, + ..Default::default() + }; + stages.push(StageParams::Ivf(ivf_params)); + + // Parse HnswBuildParams + let hnsw_params = env.get_optional_from_method( &vector_index_params_obj, - "getDistanceTypeString", - "()Ljava/lang/String;", - &[], - )? - .l()? - .into(); - let distance_type_str: String = env.get_string(&distance_type_obj)?.into(); - let distance_type = DistanceType::try_from(distance_type_str.as_str())?; - - let ivf_params_obj = env - .call_method( + "getHnswParams", + |env, hnsw_obj| { + let max_level = + env.call_method(&hnsw_obj, "getMaxLevel", "()S", &[])?.s()? as u16; + let m = env.get_int_as_usize_from_method(&hnsw_obj, "getM")?; + let ef_construction = + env.get_int_as_usize_from_method(&hnsw_obj, "getEfConstruction")?; + let prefetch_distance = + env.get_optional_usize_from_method(&hnsw_obj, "getPrefetchDistance")?; + + Ok(HnswBuildParams { + max_level, + m, + ef_construction, + prefetch_distance, + }) + }, + )?; + + if let Some(hnsw_params) = hnsw_params { + stages.push(StageParams::Hnsw(hnsw_params)); + } + + // Parse PQBuildParams + let pq_params = env.get_optional_from_method( &vector_index_params_obj, - "getIvfParams", - "()Lorg/lance/index/vector/IvfBuildParams;", - &[], - )? - .l()?; - - let mut stages = Vec::new(); - - // Parse IvfBuildParams - let num_partitions = - env.get_int_as_usize_from_method(&ivf_params_obj, "getNumPartitions")?; - let max_iters = env.get_int_as_usize_from_method(&ivf_params_obj, "getMaxIters")?; - let sample_rate = env.get_int_as_usize_from_method(&ivf_params_obj, "getSampleRate")?; - let shuffle_partition_batches = - env.get_int_as_usize_from_method(&ivf_params_obj, "getShufflePartitionBatches")?; - let shuffle_partition_concurrency = - env.get_int_as_usize_from_method(&ivf_params_obj, "getShufflePartitionConcurrency")?; - - let ivf_params = IvfBuildParams { - num_partitions: Some(num_partitions), - max_iters, - sample_rate, - shuffle_partition_batches, - shuffle_partition_concurrency, - ..Default::default() - }; - stages.push(StageParams::Ivf(ivf_params)); - - // Parse HnswBuildParams - let hnsw_params = env.get_optional_from_method( - &vector_index_params_obj, - "getHnswParams", - |env, hnsw_obj| { - let max_level = env.call_method(&hnsw_obj, "getMaxLevel", "()S", &[])?.s()? as u16; - let m = env.get_int_as_usize_from_method(&hnsw_obj, "getM")?; - let ef_construction = - env.get_int_as_usize_from_method(&hnsw_obj, "getEfConstruction")?; - let prefetch_distance = - env.get_optional_usize_from_method(&hnsw_obj, "getPrefetchDistance")?; - - Ok(HnswBuildParams { - max_level, - m, - ef_construction, - prefetch_distance, - }) - }, - )?; - - if let Some(hnsw_params) = hnsw_params { - stages.push(StageParams::Hnsw(hnsw_params)); - } - - // Parse PQBuildParams - let pq_params = env.get_optional_from_method( - &vector_index_params_obj, - "getPqParams", - |env, pq_obj| { - let num_sub_vectors = - env.get_int_as_usize_from_method(&pq_obj, "getNumSubVectors")?; - let num_bits = env.get_int_as_usize_from_method(&pq_obj, "getNumBits")?; - let max_iters = env.get_int_as_usize_from_method(&pq_obj, "getMaxIters")?; - let kmeans_redos = env.get_int_as_usize_from_method(&pq_obj, "getKmeansRedos")?; - let sample_rate = env.get_int_as_usize_from_method(&pq_obj, "getSampleRate")?; - - Ok(PQBuildParams { - num_sub_vectors, - num_bits, - max_iters, - kmeans_redos, - sample_rate, - ..Default::default() - }) - }, - )?; - - if let Some(pq_params) = pq_params { - stages.push(StageParams::PQ(pq_params)); - } - - // Parse SQBuildParams - let sq_params = env.get_optional_from_method( - &vector_index_params_obj, - "getSqParams", - |env, sq_obj| { - let num_bits = env.call_method(&sq_obj, "getNumBits", "()S", &[])?.s()? as u16; - let sample_rate = env.get_int_as_usize_from_method(&sq_obj, "getSampleRate")?; - - Ok(SQBuildParams { - num_bits, - sample_rate, - }) - }, - )?; - - if let Some(sq_params) = sq_params { - stages.push(StageParams::SQ(sq_params)); - } - - Some(VectorIndexParams { - metric_type: distance_type, - stages, - version: IndexFileVersion::V3, - }) - } else { - None - }; + "getPqParams", + |env, pq_obj| { + let num_sub_vectors = + env.get_int_as_usize_from_method(&pq_obj, "getNumSubVectors")?; + let num_bits = env.get_int_as_usize_from_method(&pq_obj, "getNumBits")?; + let max_iters = env.get_int_as_usize_from_method(&pq_obj, "getMaxIters")?; + let kmeans_redos = + env.get_int_as_usize_from_method(&pq_obj, "getKmeansRedos")?; + let sample_rate = env.get_int_as_usize_from_method(&pq_obj, "getSampleRate")?; + + Ok(PQBuildParams { + num_sub_vectors, + num_bits, + max_iters, + kmeans_redos, + sample_rate, + ..Default::default() + }) + }, + )?; + + if let Some(pq_params) = pq_params { + stages.push(StageParams::PQ(pq_params)); + } + + // Parse SQBuildParams + let sq_params = env.get_optional_from_method( + &vector_index_params_obj, + "getSqParams", + |env, sq_obj| { + let num_bits = env.call_method(&sq_obj, "getNumBits", "()S", &[])?.s()? as u16; + let sample_rate = env.get_int_as_usize_from_method(&sq_obj, "getSampleRate")?; + + Ok(SQBuildParams { + num_bits, + sample_rate, + }) + }, + )?; + + if let Some(sq_params) = sq_params { + stages.push(StageParams::SQ(sq_params)); + } + + Ok(VectorIndexParams { + metric_type: distance_type, + stages, + version: IndexFileVersion::V3, + }) + }, + )?; match vector_index_params_option { Some(params) => Ok(Box::new(params) as Box), @@ -366,46 +344,26 @@ pub fn get_scalar_index_params( env: &mut JNIEnv, index_params_obj: JObject, ) -> Result<(String, Option)> { - let scalar_params_option_object = env - .call_method( - index_params_obj, - "getScalarIndexParams", - "()Ljava/util/Optional;", - &[], - )? - .l()?; - - if env - .call_method(&scalar_params_option_object, "isPresent", "()Z", &[])? - .z()? - { - let scalar_params_obj = env - .call_method( - &scalar_params_option_object, - "get", - "()Ljava/lang/Object;", - &[], - )? - .l()?; - - let index_type = env.get_string_from_method(&scalar_params_obj, "getIndexType")?; - - let params = env.get_optional_from_method( - &scalar_params_obj, - "getJsonParams", - |env, params_obj| { - let params_str: JString = params_obj.into(); - let params_string: String = env.get_string(¶ms_str)?.into(); - Ok(params_string) - }, - )?; - - Ok((index_type, params)) - } else { - Err(Error::input_error( - "ScalarIndexParams not present".to_string(), - )) - } + env.get_optional_from_method( + &index_params_obj, + "getScalarIndexParams", + |env, scalar_params_obj| { + let index_type = env.get_string_from_method(&scalar_params_obj, "getIndexType")?; + + let params = env.get_optional_from_method( + &scalar_params_obj, + "getJsonParams", + |env, params_obj| { + let params_str: JString = params_obj.into(); + let params_string: String = env.get_string(¶ms_str)?.into(); + Ok(params_string) + }, + )?; + + Ok((index_type, params)) + }, + )? + .ok_or_else(|| Error::input_error("ScalarIndexParams not present".to_string())) } pub fn to_rust_map(env: &mut JNIEnv, jmap: &JMap) -> Result> {