From 26b0a74d50d8a069828e3982a9e0b85d3a04e290 Mon Sep 17 00:00:00 2001 From: "marco.mengelkoch" Date: Mon, 1 Dec 2025 19:08:02 +0100 Subject: [PATCH 1/4] add get feature names --- Cargo.toml | 3 ++ src/model.rs | 128 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 6c6d3c4..6e657cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,9 @@ categories = ["science"] readme = "README.md" rust-version = "1.85" +[dependencies] +libc = "0.2.177" + [build-dependencies] bindgen = "0.72.0" ureq = "2.0" diff --git a/src/model.rs b/src/model.rs index eb4a76b..11348f2 100644 --- a/src/model.rs +++ b/src/model.rs @@ -275,6 +275,134 @@ impl Model { }) } + fn get_feature_indices_from_c( + indices_ptr: *mut usize, + count: usize, + err_msg: &str, + ) -> CatBoostResult> { + if indices_ptr.is_null() { + if count == 0 { + return Ok(Vec::new()); + } + return Err(CatBoostError { + description: err_msg.to_owned(), + }); + } + let mut indices = Vec::with_capacity(count); + for i in 0..count { + indices.push(unsafe { *indices_ptr.add(i) }); + } + unsafe { libc::free(indices_ptr as *mut _) }; + Ok(indices) + } + + fn get_feature_names_from_c( + names_ptr: *mut *mut std::ffi::c_char, + count: usize, + err_msg: &str, + ) -> CatBoostResult> { + if names_ptr.is_null() { + if count == 0 { + return Ok(Vec::new()); + } + return Err(CatBoostError { + description: err_msg.to_owned(), + }); + } + let mut names = Vec::with_capacity(count); + for i in 0..count { + let ptr = unsafe { *names_ptr.add(i) }; + let s = unsafe { CStr::from_ptr(ptr) } + .to_string_lossy() + .into_owned(); + names.push(s); + unsafe { libc::free(ptr as *mut _) }; + } + unsafe { libc::free(names_ptr as *mut _) }; + Ok(names) + } + + fn get_specific_feature_names( + &self, + indices_fn: unsafe extern "C" fn( + *mut sys::ModelCalcerHandle, + *mut *mut usize, + *mut usize, + ) -> bool, + err_msg: &str, + ) -> CatBoostResult> { + let all_names = self.get_feature_names()?; + let indices = self.get_feature_indices(indices_fn, err_msg)?; + Ok(indices.into_iter().map(|i| all_names[i].clone()).collect()) + } + + /// Get names of features used in model + pub fn get_feature_names(&self) -> CatBoostResult> { + unsafe { + let mut names_ptr: *mut *mut std::ffi::c_char = std::ptr::null_mut(); + let mut count: usize = 0; + + let ok = + sys::GetModelUsedFeaturesNames(self.handle, &mut names_ptr, &mut count); + CatBoostError::check_return_value(ok)?; + + Self::get_feature_names_from_c( + names_ptr, + count, + "GetModelUsedFeaturesNames returned null pointer", + ) + } + } + + fn get_feature_indices( + &self, + indices_fn: unsafe extern "C" fn( + *mut sys::ModelCalcerHandle, + *mut *mut usize, + *mut usize, + ) -> bool, + err_msg: &str, + ) -> CatBoostResult> { + let mut indices_ptr: *mut usize = std::ptr::null_mut(); + let mut count: usize = 0; + CatBoostError::check_return_value(unsafe { + indices_fn(self.handle, &mut indices_ptr, &mut count) + })?; + Self::get_feature_indices_from_c(indices_ptr, count, err_msg) + } + + /// Get names of float features used in model + pub fn get_float_feature_names(&self) -> CatBoostResult> { + self.get_specific_feature_names( + sys::GetFloatFeatureIndices, + "GetFloatFeatureIndices returned null pointer", + ) + } + + /// Get names of cat features used in model + pub fn get_cat_feature_names(&self) -> CatBoostResult> { + self.get_specific_feature_names( + sys::GetCatFeatureIndices, + "GetCatFeatureIndices returned null pointer", + ) + } + + /// Get names of text features used in model + pub fn get_text_feature_names(&self) -> CatBoostResult> { + self.get_specific_feature_names( + sys::GetTextFeatureIndices, + "GetTextFeatureIndices returned null pointer", + ) + } + + /// Get names of embedding features used in model + pub fn get_embedding_feature_names(&self) -> CatBoostResult> { + self.get_specific_feature_names( + sys::GetEmbeddingFeatureIndices, + "GetEmbeddingFeatureIndices returned null pointer", + ) + } + /// Get expected float feature count for model pub fn get_float_features_count(&self) -> usize { unsafe { sys::GetFloatFeaturesCount(self.handle) } From 5d9e87aa2b442239e2ce7ff36e322e2eab8885f3 Mon Sep 17 00:00:00 2001 From: "marco.mengelkoch" Date: Mon, 1 Dec 2025 19:18:16 +0100 Subject: [PATCH 2/4] fmt --- src/model.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/model.rs b/src/model.rs index 11348f2..ab65109 100644 --- a/src/model.rs +++ b/src/model.rs @@ -335,15 +335,14 @@ impl Model { let indices = self.get_feature_indices(indices_fn, err_msg)?; Ok(indices.into_iter().map(|i| all_names[i].clone()).collect()) } - + /// Get names of features used in model pub fn get_feature_names(&self) -> CatBoostResult> { unsafe { let mut names_ptr: *mut *mut std::ffi::c_char = std::ptr::null_mut(); let mut count: usize = 0; - let ok = - sys::GetModelUsedFeaturesNames(self.handle, &mut names_ptr, &mut count); + let ok = sys::GetModelUsedFeaturesNames(self.handle, &mut names_ptr, &mut count); CatBoostError::check_return_value(ok)?; Self::get_feature_names_from_c( From 68ce50ed3c8c78274eca6bf67f5e0e221c696f7a Mon Sep 17 00:00:00 2001 From: "marco.mengelkoch" Date: Tue, 2 Dec 2025 18:17:23 +0100 Subject: [PATCH 3/4] error handling and specific cfg --- src/model.rs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/model.rs b/src/model.rs index ab65109..a6445ba 100644 --- a/src/model.rs +++ b/src/model.rs @@ -322,6 +322,8 @@ impl Model { Ok(names) } + /// Get names of specific type of features used in model, + /// returns error if index out of bounds fn get_specific_feature_names( &self, indices_fn: unsafe extern "C" fn( @@ -333,10 +335,21 @@ impl Model { ) -> CatBoostResult> { let all_names = self.get_feature_names()?; let indices = self.get_feature_indices(indices_fn, err_msg)?; - Ok(indices.into_iter().map(|i| all_names[i].clone()).collect()) + indices + .into_iter() + .map(|i| { + all_names + .get(i) + .ok_or_else(|| CatBoostError { + description: format!("feature index {} out of bounds", i), + }) + .map(|s| s.clone()) + }) + .collect() } /// Get names of features used in model + #[cfg(catboost_feature_indices)] pub fn get_feature_names(&self) -> CatBoostResult> { unsafe { let mut names_ptr: *mut *mut std::ffi::c_char = std::ptr::null_mut(); @@ -371,6 +384,7 @@ impl Model { } /// Get names of float features used in model + #[cfg(catboost_feature_indices)] pub fn get_float_feature_names(&self) -> CatBoostResult> { self.get_specific_feature_names( sys::GetFloatFeatureIndices, @@ -379,6 +393,7 @@ impl Model { } /// Get names of cat features used in model + #[cfg(catboost_feature_indices)] pub fn get_cat_feature_names(&self) -> CatBoostResult> { self.get_specific_feature_names( sys::GetCatFeatureIndices, @@ -387,6 +402,7 @@ impl Model { } /// Get names of text features used in model + #[cfg(catboost_feature_indices)] pub fn get_text_feature_names(&self) -> CatBoostResult> { self.get_specific_feature_names( sys::GetTextFeatureIndices, @@ -395,6 +411,7 @@ impl Model { } /// Get names of embedding features used in model + #[cfg(catboost_feature_indices)] pub fn get_embedding_feature_names(&self) -> CatBoostResult> { self.get_specific_feature_names( sys::GetEmbeddingFeatureIndices, From 2831014c720b1e6c6fa7dcdd9955d5076b817d87 Mon Sep 17 00:00:00 2001 From: "marco.mengelkoch" Date: Mon, 15 Dec 2025 22:48:39 +0100 Subject: [PATCH 4/4] add safety comments --- Cargo.toml | 2 +- src/model.rs | 43 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6e657cf..3328b39 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ readme = "README.md" rust-version = "1.85" [dependencies] -libc = "0.2.177" +libc = "0.2.178" [build-dependencies] bindgen = "0.72.0" diff --git a/src/model.rs b/src/model.rs index a6445ba..a8a75a4 100644 --- a/src/model.rs +++ b/src/model.rs @@ -275,6 +275,34 @@ impl Model { }) } + /// # Safety + /// + /// This function is unsafe because it dereferences a raw pointer and assumes a memory + /// allocation contract with an external C API. + /// + /// - `ptr` must be a valid pointer to a C-allocated buffer containing `count` elements of type `T`, + /// or it must be a null pointer if `count` is 0. + /// - The buffer must have been allocated by a `malloc`-compatible allocator, as the CatBoost C API + /// documentation for functions like `GetFloatFeatureIndices` and `GetModelUsedFeaturesNames` + /// stipulates that the caller is responsible for freeing the returned buffer. The standard C + /// mechanism for this is `free()`. + /// (Source: https://github.com/catboost/catboost/blob/master/catboost/libs/model_interface/c_api.h) + /// + /// This function takes ownership of the buffer and frees it with `libc::free` after copying + /// the data into a Rust `Vec`. + unsafe fn from_c_allocated_buffer(ptr: *mut T, count: usize) -> Vec { + if ptr.is_null() { + return Vec::new(); + } + let mut result = Vec::with_capacity(count); + for i in 0..count { + result.push(unsafe { *ptr.add(i) }); + } + unsafe { libc::free(ptr as *mut _) }; + result + } + + /// Converts a C-style array of feature indices into a `Vec`, freeing the C buffer. fn get_feature_indices_from_c( indices_ptr: *mut usize, count: usize, @@ -288,14 +316,14 @@ impl Model { description: err_msg.to_owned(), }); } - let mut indices = Vec::with_capacity(count); - for i in 0..count { - indices.push(unsafe { *indices_ptr.add(i) }); - } - unsafe { libc::free(indices_ptr as *mut _) }; + // SAFETY: The contract for CatBoost functions like `GetFloatFeatureIndices` is that they + // return a `malloc`-allocated buffer that the caller must free. `from_c_allocated_buffer` + // upholds this contract by copying the data and then calling `libc::free`. + let indices = unsafe { Self::from_c_allocated_buffer(indices_ptr, count) }; Ok(indices) } + /// Converts a C-style array of C strings into a `Vec`, freeing all associated C memory. fn get_feature_names_from_c( names_ptr: *mut *mut std::ffi::c_char, count: usize, @@ -309,6 +337,9 @@ impl Model { description: err_msg.to_owned(), }); } + // SAFETY: The contract for `GetModelUsedFeaturesNames` is that it returns a `malloc`-allocated + // array of `malloc`-allocated strings. The caller must free both the outer array and each + // inner string pointer. This block upholds that contract. let mut names = Vec::with_capacity(count); for i in 0..count { let ptr = unsafe { *names_ptr.add(i) }; @@ -322,7 +353,7 @@ impl Model { Ok(names) } - /// Get names of specific type of features used in model, + /// Get names of specific type of features used in model, /// returns error if index out of bounds fn get_specific_feature_names( &self,