Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ categories = ["science"]
readme = "README.md"
rust-version = "1.85"

[dependencies]
libc = "0.2.178"

[build-dependencies]
bindgen = "0.72.0"
ureq = "2.0"
Expand Down
175 changes: 175 additions & 0 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,181 @@ impl Model {
})
}

/// Copies `count` elements out of a C-allocated buffer into a `Vec<T>`, then releases the
/// buffer with `libc::free`.
///
/// A null `ptr` yields an empty `Vec` and nothing is freed.
///
/// # Safety
///
/// This function dereferences a raw pointer and relies on a memory allocation contract with
/// an external C API.
///
/// - If `ptr` is non-null it must point to a readable, properly aligned buffer holding at least
///   `count` initialized values of type `T`.
/// - The buffer must have been allocated by a `malloc`-compatible allocator, as the CatBoost C API
///   documentation for functions like `GetFloatFeatureIndices` and `GetModelUsedFeaturesNames`
///   stipulates that the caller is responsible for freeing the returned buffer. The standard C
///   mechanism for this is `free()`.
///   (Source: https://github.com/catboost/catboost/blob/master/catboost/libs/model_interface/c_api.h)
/// - Ownership of the buffer is transferred to this function: once it returns, `ptr` is
///   dangling and must not be read or freed again.
unsafe fn from_c_allocated_buffer<T: Copy>(ptr: *mut T, count: usize) -> Vec<T> {
    if ptr.is_null() {
        return Vec::new();
    }
    // SAFETY: `ptr` is non-null and, per this function's contract, addresses `count`
    // initialized, aligned values of `T`. The temporary slice is only used for this bulk copy
    // and no longer exists when the buffer is released below.
    let result = unsafe { std::slice::from_raw_parts(ptr, count) }.to_vec();
    // SAFETY: the caller guarantees the buffer came from a `malloc`-compatible allocator and
    // hands over ownership, so it is freed exactly once, here.
    unsafe { libc::free(ptr.cast()) };
    result
}

/// Turns an index array handed back by the CatBoost C API into a `Vec<usize>`.
///
/// The C buffer is consumed: its contents are copied and the buffer is released by
/// `from_c_allocated_buffer`.
///
/// * A null `indices_ptr` with `count == 0` means "no indices" and gives an empty vector.
/// * A null `indices_ptr` with a non-zero `count` is reported as an error whose description
///   is `err_msg`.
///
/// NOTE(review): this is a safe `fn` that ends up dereferencing `indices_ptr`; it relies on
/// callers only passing pointers obtained from the CatBoost index getters.
fn get_feature_indices_from_c(
    indices_ptr: *mut usize,
    count: usize,
    err_msg: &str,
) -> CatBoostResult<Vec<usize>> {
    match (indices_ptr.is_null(), count) {
        (true, 0) => Ok(Vec::new()),
        (true, _) => Err(CatBoostError {
            description: err_msg.to_owned(),
        }),
        // SAFETY: CatBoost functions such as `GetFloatFeatureIndices` return a
        // `malloc`-allocated buffer of `count` indices that the caller has to free.
        // `from_c_allocated_buffer` copies the data out and then calls `libc::free`, which
        // is exactly that contract.
        (false, _) => Ok(unsafe { Self::from_c_allocated_buffer(indices_ptr, count) }),
    }
}

/// Converts a C-style array of C strings into a `Vec<String>`, freeing all associated C memory.
///
/// Each string is decoded lossily (invalid UTF-8 is replaced), copied into an owned `String`,
/// and then freed; the outer array is freed last.
///
/// * A null `names_ptr` with `count == 0` gives an empty vector.
/// * A null `names_ptr` with a non-zero `count` is an error.
/// * A null entry inside the array is also an error. The remaining entries and the array
///   itself are still freed before the error is returned, so nothing leaks.
///
/// NOTE(review): this is a safe `fn` that dereferences `names_ptr`; it relies on callers only
/// passing the pointer/count pair produced by `GetModelUsedFeaturesNames`.
///
/// # Errors
///
/// Returns a `CatBoostError` whose description is `err_msg` when the array pointer (with a
/// non-zero `count`) or any of its entries is null.
fn get_feature_names_from_c(
    names_ptr: *mut *mut std::ffi::c_char,
    count: usize,
    err_msg: &str,
) -> CatBoostResult<Vec<String>> {
    if names_ptr.is_null() {
        if count == 0 {
            return Ok(Vec::new());
        }
        return Err(CatBoostError {
            description: err_msg.to_owned(),
        });
    }

    let mut names = Vec::with_capacity(count);
    let mut saw_null_entry = false;
    {
        // SAFETY: the contract for `GetModelUsedFeaturesNames` is that it returns a
        // `malloc`-allocated array of `count` string pointers; `names_ptr` was checked to be
        // non-null above. The slice does not outlive this scope, so it is gone before the
        // array is freed.
        let entries = unsafe { std::slice::from_raw_parts(names_ptr, count) };
        for &entry in entries {
            if entry.is_null() {
                // `CStr::from_ptr` on a null pointer is undefined behaviour. Remember the
                // failure but keep going so the other allocations are still released.
                saw_null_entry = true;
                continue;
            }
            // SAFETY: `entry` is non-null and, per the same contract, points to a
            // NUL-terminated, `malloc`-allocated string owned by the caller.
            let name = unsafe { CStr::from_ptr(entry) }
                .to_string_lossy()
                .into_owned();
            names.push(name);
            // SAFETY: the string has been copied into `name`; the caller-owned original is
            // freed exactly once.
            unsafe { libc::free(entry.cast()) };
        }
    }
    // SAFETY: every entry has been handled; the outer array is caller-owned and freed once.
    unsafe { libc::free(names_ptr.cast()) };

    if saw_null_entry {
        return Err(CatBoostError {
            description: err_msg.to_owned(),
        });
    }
    Ok(names)
}

/// Get names of one specific kind of feature used in the model.
///
/// Fetches all feature names with `get_feature_names`, asks `indices_fn` (through
/// `get_feature_indices`) for the positions of the wanted feature kind, and returns a clone
/// of the name found at each position, in the order the indices were returned.
///
/// Compiled only under `catboost_feature_indices`, because it calls `get_feature_names`,
/// which exists only under that cfg; without the gate the crate would not build when the cfg
/// is absent.
///
/// # Errors
///
/// - Any error from `get_feature_names` or `get_feature_indices` is propagated; `err_msg` is
///   forwarded to the latter for its null-pointer error.
/// - Returns an error if an index returned by `indices_fn` is out of bounds for the list of
///   feature names.
#[cfg(catboost_feature_indices)]
fn get_specific_feature_names(
    &self,
    indices_fn: unsafe extern "C" fn(
        *mut sys::ModelCalcerHandle,
        *mut *mut usize,
        *mut usize,
    ) -> bool,
    err_msg: &str,
) -> CatBoostResult<Vec<String>> {
    let all_names = self.get_feature_names()?;
    let indices = self.get_feature_indices(indices_fn, err_msg)?;
    indices
        .into_iter()
        .map(|i| {
            // Names are cloned rather than moved out of `all_names`; nothing here rules out
            // the same index appearing more than once.
            all_names.get(i).cloned().ok_or_else(|| CatBoostError {
                description: format!("feature index {} out of bounds", i),
            })
        })
        .collect()
}

/// Returns the names of the features used in the model.
///
/// Calls `GetModelUsedFeaturesNames` and hands the resulting C string array to
/// `get_feature_names_from_c`, which copies the names into owned `String`s and frees the
/// C memory.
///
/// # Errors
///
/// Propagates the error produced by `CatBoostError::check_return_value` for the status flag
/// of the C call, and any error from `get_feature_names_from_c` (for example a null array
/// together with a non-zero count).
#[cfg(catboost_feature_indices)]
pub fn get_feature_names(&self) -> CatBoostResult<Vec<String>> {
    let mut raw_names: *mut *mut std::ffi::c_char = std::ptr::null_mut();
    let mut raw_len: usize = 0;

    // SAFETY: both out-parameters point at live locals of the expected types.
    // NOTE(review): assumes `self.handle` is a valid model handle for the duration of the
    // call — confirm against the constructor/`Drop` of `Model`.
    let status =
        unsafe { sys::GetModelUsedFeaturesNames(self.handle, &mut raw_names, &mut raw_len) };
    CatBoostError::check_return_value(status)?;

    Self::get_feature_names_from_c(
        raw_names,
        raw_len,
        "GetModelUsedFeaturesNames returned null pointer",
    )
}

/// Runs one of the CatBoost index getters and returns its result as a `Vec<usize>`.
///
/// `indices_fn` is called with the model handle and two out-parameters (array pointer and
/// element count). The returned array is converted, and freed, by
/// `get_feature_indices_from_c`; `err_msg` is forwarded to it for the case where the getter
/// hands back a null array with a non-zero count.
///
/// # Errors
///
/// Propagates the error produced by `CatBoostError::check_return_value` for the getter's
/// status flag, and any error from `get_feature_indices_from_c`.
fn get_feature_indices(
    &self,
    indices_fn: unsafe extern "C" fn(
        *mut sys::ModelCalcerHandle,
        *mut *mut usize,
        *mut usize,
    ) -> bool,
    err_msg: &str,
) -> CatBoostResult<Vec<usize>> {
    let mut raw_indices = std::ptr::null_mut::<usize>();
    let mut raw_len = 0usize;

    // SAFETY: both out-parameters point at live locals of the expected types.
    // NOTE(review): assumes `self.handle` is a valid model handle and that `indices_fn` is
    // one of the CatBoost getters with this signature — confirm at the call sites.
    let status = unsafe { indices_fn(self.handle, &mut raw_indices, &mut raw_len) };
    CatBoostError::check_return_value(status)?;

    Self::get_feature_indices_from_c(raw_indices, raw_len, err_msg)
}

/// Returns the names of the float features used in the model.
///
/// Delegates to `get_specific_feature_names` with `GetFloatFeatureIndices` as the index
/// getter.
///
/// # Errors
///
/// Propagates any error from `get_specific_feature_names`.
#[cfg(catboost_feature_indices)]
pub fn get_float_feature_names(&self) -> CatBoostResult<Vec<String>> {
    // Description used when the getter returns a null array with a non-zero count.
    const NULL_PTR_MSG: &str = "GetFloatFeatureIndices returned null pointer";
    self.get_specific_feature_names(sys::GetFloatFeatureIndices, NULL_PTR_MSG)
}

/// Returns the names of the categorical features used in the model.
///
/// Delegates to `get_specific_feature_names` with `GetCatFeatureIndices` as the index
/// getter.
///
/// # Errors
///
/// Propagates any error from `get_specific_feature_names`.
#[cfg(catboost_feature_indices)]
pub fn get_cat_feature_names(&self) -> CatBoostResult<Vec<String>> {
    // Description used when the getter returns a null array with a non-zero count.
    const NULL_PTR_MSG: &str = "GetCatFeatureIndices returned null pointer";
    self.get_specific_feature_names(sys::GetCatFeatureIndices, NULL_PTR_MSG)
}

/// Returns the names of the text features used in the model.
///
/// Delegates to `get_specific_feature_names` with `GetTextFeatureIndices` as the index
/// getter.
///
/// # Errors
///
/// Propagates any error from `get_specific_feature_names`.
#[cfg(catboost_feature_indices)]
pub fn get_text_feature_names(&self) -> CatBoostResult<Vec<String>> {
    // Description used when the getter returns a null array with a non-zero count.
    const NULL_PTR_MSG: &str = "GetTextFeatureIndices returned null pointer";
    self.get_specific_feature_names(sys::GetTextFeatureIndices, NULL_PTR_MSG)
}

/// Returns the names of the embedding features used in the model.
///
/// Delegates to `get_specific_feature_names` with `GetEmbeddingFeatureIndices` as the index
/// getter.
///
/// # Errors
///
/// Propagates any error from `get_specific_feature_names`.
#[cfg(catboost_feature_indices)]
pub fn get_embedding_feature_names(&self) -> CatBoostResult<Vec<String>> {
    // Description used when the getter returns a null array with a non-zero count.
    const NULL_PTR_MSG: &str = "GetEmbeddingFeatureIndices returned null pointer";
    self.get_specific_feature_names(sys::GetEmbeddingFeatureIndices, NULL_PTR_MSG)
}

/// Get expected float feature count for model
pub fn get_float_features_count(&self) -> usize {
unsafe { sys::GetFloatFeaturesCount(self.handle) }
Expand Down