diff --git a/Cargo.toml b/Cargo.toml
index 6c6d3c4..3328b39 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,6 +10,9 @@ categories = ["science"]
 readme = "README.md"
 rust-version = "1.85"
 
+[dependencies]
+libc = "0.2.178"
+
 [build-dependencies]
 bindgen = "0.72.0"
 ureq = "2.0"
diff --git a/src/model.rs b/src/model.rs
index eb4a76b..a8a75a4 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -275,6 +275,241 @@ impl Model {
         })
     }
 
+    /// Copies a C-allocated array into a `Vec<T>` and frees the C buffer.
+    ///
+    /// # Safety
+    ///
+    /// This function is unsafe because it dereferences a raw pointer and assumes a memory
+    /// allocation contract with an external C API.
+    ///
+    /// - `ptr` must be a valid pointer to a C-allocated buffer containing `count` elements of
+    ///   type `T`, or it must be a null pointer if `count` is 0.
+    /// - The buffer must have been allocated by a `malloc`-compatible allocator, as the CatBoost
+    ///   C API documentation for functions like `GetFloatFeatureIndices` and
+    ///   `GetModelUsedFeaturesNames` stipulates that the caller is responsible for freeing the
+    ///   returned buffer. The standard C mechanism for this is `free()`.
+    ///   (Source: https://github.com/catboost/catboost/blob/master/catboost/libs/model_interface/c_api.h)
+    ///
+    /// This function takes ownership of the buffer and frees it with `libc::free` after copying
+    /// the data into a Rust `Vec`, so `ptr` must not be used after the call.
+    #[cfg(catboost_feature_indices)]
+    unsafe fn from_c_allocated_buffer<T: Copy>(ptr: *mut T, count: usize) -> Vec<T> {
+        if ptr.is_null() {
+            return Vec::new();
+        }
+        // SAFETY: `ptr` is non-null and, per this function's contract, points to `count`
+        // initialized elements of `T`; they are copied out before the buffer is freed.
+        let values = unsafe { std::slice::from_raw_parts(ptr, count) }.to_vec();
+        // SAFETY: the contract guarantees a `malloc`-compatible allocation that this
+        // function now owns, and nothing reads the buffer after this point.
+        unsafe { libc::free(ptr.cast::<libc::c_void>()) };
+        values
+    }
+
+    /// Converts a C-style array of feature indices into a `Vec<usize>`, freeing the C buffer.
+    ///
+    /// Returns an empty vector when `indices_ptr` is null and `count` is 0.
+    ///
+    /// # Errors
+    ///
+    /// Returns a `CatBoostError` carrying `err_msg` when `indices_ptr` is null but `count`
+    /// is non-zero.
+    ///
+    /// # Safety
+    ///
+    /// A non-null `indices_ptr` must satisfy the contract of `from_c_allocated_buffer` for
+    /// `count` elements of `usize`.
+    #[cfg(catboost_feature_indices)]
+    unsafe fn get_feature_indices_from_c(
+        indices_ptr: *mut usize,
+        count: usize,
+        err_msg: &str,
+    ) -> CatBoostResult<Vec<usize>> {
+        if indices_ptr.is_null() && count != 0 {
+            return Err(CatBoostError {
+                description: err_msg.to_owned(),
+            });
+        }
+        // SAFETY: the pointer is either null (handled by the callee) or, per this function's
+        // contract, a `malloc`-allocated buffer of `count` indices that we now own.
+        Ok(unsafe { Self::from_c_allocated_buffer(indices_ptr, count) })
+    }
+
+    /// Converts a C-style array of C strings into a `Vec<String>`, freeing all associated C
+    /// memory.
+    ///
+    /// Every non-null string and the outer array are freed even when an error is returned.
+    /// Names that are not valid UTF-8 are converted lossily.
+    ///
+    /// # Errors
+    ///
+    /// Returns a `CatBoostError` carrying `err_msg` when `names_ptr` is null but `count` is
+    /// non-zero, or when any entry of the array is a null pointer.
+    ///
+    /// # Safety
+    ///
+    /// A non-null `names_ptr` must point to `count` pointers, each either null or pointing to
+    /// a NUL-terminated string. The outer array and every string must have been allocated by
+    /// a `malloc`-compatible allocator; ownership of all of them passes to this function.
+    #[cfg(catboost_feature_indices)]
+    unsafe fn get_feature_names_from_c(
+        names_ptr: *mut *mut std::ffi::c_char,
+        count: usize,
+        err_msg: &str,
+    ) -> CatBoostResult<Vec<String>> {
+        if names_ptr.is_null() {
+            return if count == 0 {
+                Ok(Vec::new())
+            } else {
+                Err(CatBoostError {
+                    description: err_msg.to_owned(),
+                })
+            };
+        }
+        // SAFETY: `names_ptr` is non-null and, per this function's contract, points to
+        // `count` string pointers.
+        let raw_names = unsafe { std::slice::from_raw_parts(names_ptr, count) };
+        let mut names = Vec::with_capacity(count);
+        let mut saw_null_entry = false;
+        for &name_ptr in raw_names {
+            if name_ptr.is_null() {
+                // `CStr::from_ptr` must never see a null pointer. Remember the failure but
+                // keep iterating so the remaining strings are still freed.
+                saw_null_entry = true;
+                continue;
+            }
+            // SAFETY: `name_ptr` is non-null and NUL-terminated per this function's contract.
+            let name = unsafe { std::ffi::CStr::from_ptr(name_ptr) }
+                .to_string_lossy()
+                .into_owned();
+            names.push(name);
+            // SAFETY: the string is `malloc`-allocated, owned by us, and already copied.
+            unsafe { libc::free(name_ptr.cast::<libc::c_void>()) };
+        }
+        // SAFETY: the outer array is `malloc`-allocated and owned by us; `raw_names` is not
+        // read after this point.
+        unsafe { libc::free(names_ptr.cast::<libc::c_void>()) };
+        if saw_null_entry {
+            return Err(CatBoostError {
+                description: err_msg.to_owned(),
+            });
+        }
+        Ok(names)
+    }
+
+    /// Get names of specific type of features used in model,
+    /// returns error if index out of bounds
+    ///
+    /// `indices_fn` is the C API getter for the wanted feature kind; `err_msg` is the error
+    /// description used when that getter reports a non-zero count with a null buffer.
+    // NOTE(review): assumes the indices returned by `indices_fn` are positions in the list
+    // returned by `get_feature_names` -- confirm against the CatBoost C API.
+    #[cfg(catboost_feature_indices)]
+    fn get_specific_feature_names(
+        &self,
+        indices_fn: unsafe extern "C" fn(
+            *mut sys::ModelCalcerHandle,
+            *mut *mut usize,
+            *mut usize,
+        ) -> bool,
+        err_msg: &str,
+    ) -> CatBoostResult<Vec<String>> {
+        let all_names = self.get_feature_names()?;
+        let indices = self.get_feature_indices(indices_fn, err_msg)?;
+        indices
+            .into_iter()
+            .map(|i| {
+                all_names.get(i).cloned().ok_or_else(|| CatBoostError {
+                    description: format!("feature index {} out of bounds", i),
+                })
+            })
+            .collect()
+    }
+
+    /// Get names of features used in model
+    #[cfg(catboost_feature_indices)]
+    pub fn get_feature_names(&self) -> CatBoostResult<Vec<String>> {
+        let mut names_ptr: *mut *mut std::ffi::c_char = std::ptr::null_mut();
+        let mut count: usize = 0;
+
+        // SAFETY: both out-pointers refer to live locals; assumes `self.handle` is a valid
+        // model handle for as long as `self` exists.
+        let ok =
+            unsafe { sys::GetModelUsedFeaturesNames(self.handle, &mut names_ptr, &mut count) };
+        CatBoostError::check_return_value(ok)?;
+
+        // SAFETY: on success the C API hands over a `malloc`-allocated array of `count`
+        // `malloc`-allocated strings, which the callee copies and frees.
+        unsafe {
+            Self::get_feature_names_from_c(
+                names_ptr,
+                count,
+                "GetModelUsedFeaturesNames returned null pointer",
+            )
+        }
+    }
+
+    /// Fetches a list of feature indices through `indices_fn` and copies it into a `Vec`.
+    ///
+    /// `err_msg` becomes the error description when the getter succeeds but reports a
+    /// non-zero count together with a null buffer.
+    #[cfg(catboost_feature_indices)]
+    fn get_feature_indices(
+        &self,
+        indices_fn: unsafe extern "C" fn(
+            *mut sys::ModelCalcerHandle,
+            *mut *mut usize,
+            *mut usize,
+        ) -> bool,
+        err_msg: &str,
+    ) -> CatBoostResult<Vec<usize>> {
+        let mut indices_ptr: *mut usize = std::ptr::null_mut();
+        let mut count: usize = 0;
+        // SAFETY: both out-pointers refer to live locals; assumes `indices_fn` is one of the
+        // CatBoost index getters and `self.handle` is a valid model handle.
+        let ok = unsafe { indices_fn(self.handle, &mut indices_ptr, &mut count) };
+        CatBoostError::check_return_value(ok)?;
+        // SAFETY: on success the getter hands over a `malloc`-allocated buffer of `count`
+        // indices (or null), which the callee copies and frees.
+        unsafe { Self::get_feature_indices_from_c(indices_ptr, count, err_msg) }
+    }
+
+    /// Get names of float features used in model
+    #[cfg(catboost_feature_indices)]
+    pub fn get_float_feature_names(&self) -> CatBoostResult<Vec<String>> {
+        self.get_specific_feature_names(
+            sys::GetFloatFeatureIndices,
+            "GetFloatFeatureIndices returned null pointer",
+        )
+    }
+
+    /// Get names of cat features used in model
+    #[cfg(catboost_feature_indices)]
+    pub fn get_cat_feature_names(&self) -> CatBoostResult<Vec<String>> {
+        self.get_specific_feature_names(
+            sys::GetCatFeatureIndices,
+            "GetCatFeatureIndices returned null pointer",
+        )
+    }
+
+    /// Get names of text features used in model
+    #[cfg(catboost_feature_indices)]
+    pub fn get_text_feature_names(&self) -> CatBoostResult<Vec<String>> {
+        self.get_specific_feature_names(
+            sys::GetTextFeatureIndices,
+            "GetTextFeatureIndices returned null pointer",
+        )
+    }
+
+    /// Get names of embedding features used in model
+    #[cfg(catboost_feature_indices)]
+    pub fn get_embedding_feature_names(&self) -> CatBoostResult<Vec<String>> {
+        self.get_specific_feature_names(
+            sys::GetEmbeddingFeatureIndices,
+            "GetEmbeddingFeatureIndices returned null pointer",
+        )
+    }
+
     /// Get expected float feature count for model
     pub fn get_float_features_count(&self) -> usize {
         unsafe { sys::GetFloatFeaturesCount(self.handle) }