Skip to content
This repository was archived by the owner on Jul 16, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions examples/k-means_generating_cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,17 @@ fn main() {
// Create a new model with 2 clusters
let mut model = KMeansClassifier::new(2);

println!("Training the model...");
// Train the model
model.train(&samples);
println!("Training the model...");
// Our train function returns a Result<(), E>
model.train(&samples).unwrap();

let centroids = model.centroids().as_ref().unwrap();
println!("Model Centroids:\n{:.3}", centroids);

// Predict the classes and partition into
println!("Classifying the samples...");
let classes = model.predict(&samples);
let classes = model.predict(&samples).unwrap();
let (first, second): (Vec<usize>, Vec<usize>) = classes.data().iter().partition(|&x| *x == 0);

println!("Samples closest to first centroid: {}", first.len());
Expand Down
5 changes: 3 additions & 2 deletions examples/nnet-and_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ fn main() {
let mut model = NeuralNet::new(layers, criterion, StochasticGD::default());

println!("Training...");
model.train(&inputs, &targets);
// Our train function returns a Result<(), E>
model.train(&inputs, &targets).unwrap();

let test_cases = vec![
0.0, 0.0,
Expand All @@ -59,7 +60,7 @@ fn main() {
0.0,
];
let test_inputs = Matrix::new(test_cases.len() / 2, 2, test_cases);
let res = model.predict(&test_inputs);
let res = model.predict(&test_inputs).unwrap();

println!("Evaluation...");
let mut hits = 0;
Expand Down
5 changes: 3 additions & 2 deletions examples/svm-sign_learner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ fn main() {

// Trainee
let mut svm_mod = SVM::new(HyperTan::new(100., 0.), 0.3);
svm_mod.train(&inputs, &targets);
// Our train function returns a Result<(), E>
svm_mod.train(&inputs, &targets).unwrap();

println!("Evaluation...");
let mut hits = 0;
Expand All @@ -41,7 +42,7 @@ fn main() {
for n in (-1000..1000).filter(|&x| x % 100 == 0) {
let nf = n as f64;
let input = Matrix::new(1, 1, vec![nf]);
let out = svm_mod.predict(&input);
let out = svm_mod.predict(&input).unwrap();
let res = if out[0] * nf > 0. {
hits += 1;
true
Expand Down
20 changes: 12 additions & 8 deletions src/learning/dbscan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@
//! -2.2, 3.1]);
//!
//! let mut model = DBSCAN::new(0.5, 2);
//! model.train(&inputs);
//! model.train(&inputs).unwrap();
//!
//! let clustering = model.clusters().unwrap();
//! ```

use learning::UnSupModel;
use learning::{LearningResult, UnSupModel};
use learning::error::{Error, ErrorKind};

use linalg::{Matrix, Vector};
use rulinalg::utils;
Expand Down Expand Up @@ -75,7 +76,7 @@ impl Default for DBSCAN {

impl UnSupModel<Matrix<f64>, Vector<Option<usize>>> for DBSCAN {
/// Train the classifier using input data.
fn train(&mut self, inputs: &Matrix<f64>) {
fn train(&mut self, inputs: &Matrix<f64>) -> LearningResult<()> {
self.init_params(inputs.rows());
let mut cluster = 0;

Expand All @@ -95,11 +96,13 @@ impl UnSupModel<Matrix<f64>, Vector<Option<usize>>> for DBSCAN {
}

if self.predictive {
self._cluster_data = Some(inputs.clone())
self._cluster_data = Some(inputs.clone());
}

Ok(())
}

fn predict(&self, inputs: &Matrix<f64>) -> Vector<Option<usize>> {
fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Vector<Option<usize>>> {
if self.predictive {
if let (&Some(ref cluster_data), &Some(ref clusters)) = (&self._cluster_data,
&self.clusters) {
Expand All @@ -122,12 +125,13 @@ impl UnSupModel<Matrix<f64>, Vector<Option<usize>>> for DBSCAN {
}
}

Vector::new(classes)
Ok(Vector::new(classes))
} else {
panic!("The model has not been trained.");
Err(Error::new_untrained())
}
} else {
panic!("Model must be set to predictive. Use `self.set_predictive(true)`.");
Err(Error::new(ErrorKind::InvalidState,
"Model must be set to predictive. Use `self.set_predictive(true)`."))
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions src/learning/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ pub enum ErrorKind {
InvalidData,
/// The action could not be carried out as the model was in an invalid state.
InvalidState,
/// The model has not been trained
UntrainedModel
}

impl Error {
Expand All @@ -38,6 +40,13 @@ impl Error {
}
}

/// Returns a new error for an untrained model.
///
/// Intended for use when a model is queried (e.g. via `predict`) before
/// `train` has been called; the error has kind `ErrorKind::UntrainedModel`.
///
/// This function is unstable and may be removed with changes to the API.
pub fn new_untrained() -> Error {
    // Centralized constructor so every model reports the same message
    // when it is used before being trained.
    Error::new(ErrorKind::UntrainedModel, "The model has not been trained.")
}

/// Get the kind of this `Error`.
pub fn kind(&self) -> &ErrorKind {
&self.kind
Expand Down
22 changes: 13 additions & 9 deletions src/learning/glm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
//! let mut log_mod = GenLinearModel::new(Bernoulli);
//!
//! // Train the model
//! log_mod.train(&inputs, &targets);
//! log_mod.train(&inputs, &targets).unwrap();
//!
//! // Now we'll predict a new point
//! let new_point = Matrix::new(1,1,vec![10.]);
//! let output = log_mod.predict(&new_point);
//! let output = log_mod.predict(&new_point).unwrap();
//!
//! // Hopefully we classified our new point correctly!
//! assert!(output[0] > 0.5, "Our classifier isn't very good!");
Expand All @@ -38,7 +38,8 @@
use linalg::Vector;
use linalg::Matrix;

use learning::SupModel;
use learning::{LearningResult, SupModel};
use learning::error::{Error, ErrorKind};

/// The Generalized Linear Model
///
Expand Down Expand Up @@ -78,22 +79,24 @@ impl<C: Criterion> GenLinearModel<C> {
/// The model is trained using Iteratively Re-weighted Least Squares.
impl<C: Criterion> SupModel<Matrix<f64>, Vector<f64>> for GenLinearModel<C> {
/// Predict output from inputs.
fn predict(&self, inputs: &Matrix<f64>) -> Vector<f64> {
fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Vector<f64>> {
if let Some(ref v) = self.parameters {
let ones = Matrix::<f64>::ones(inputs.rows(), 1);
let full_inputs = ones.hcat(inputs);
self.criterion.apply_link_inv(full_inputs * v)
Ok(self.criterion.apply_link_inv(full_inputs * v))
} else {
panic!("The model has not been trained.");
Err(Error::new_untrained())
}
}

/// Train the model using inputs and targets.
fn train(&mut self, inputs: &Matrix<f64>, targets: &Vector<f64>) {
fn train(&mut self, inputs: &Matrix<f64>, targets: &Vector<f64>) -> LearningResult<()> {
let n = inputs.rows();

assert!(n == targets.size(),
"Training data do not have the same dimensions.");
if n != targets.size() {
return Err(Error::new(ErrorKind::InvalidData,
"Training data do not have the same dimensions"));
}

// Construct initial estimate for mu
let mut mu = Vector::new(self.criterion.initialize_mu(targets.data()));
Expand Down Expand Up @@ -132,6 +135,7 @@ impl<C: Criterion> SupModel<Matrix<f64>, Vector<f64>> for GenLinearModel<C> {
}

self.parameters = Some(beta);
Ok(())
}
}

Expand Down
68 changes: 36 additions & 32 deletions src/learning/gmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
//! model.cov_option = CovOption::Diagonal;
//!
//! // Where inputs is a Matrix with features in columns.
//! model.train(&inputs);
//! model.train(&inputs).unwrap();
//!
//! // Print the means and covariances of the GMM
//! println!("{:?}", model.means());
//! println!("{:?}", model.covariances());
//!
//! // Where test_inputs is a Matrix with features in columns.
//! let post_probs = model.predict(&test_inputs);
//! let post_probs = model.predict(&test_inputs).unwrap();
//!
//! // Probabilities that each point comes from each Gaussian.
//! println!("{:?}", post_probs.data());
Expand All @@ -34,8 +34,9 @@
use linalg::{Matrix, MatrixSlice, Vector};
use rulinalg::utils;

use learning::UnSupModel;
use learning::{LearningResult, UnSupModel};
use learning::toolkit::rand_utils;
use learning::error::{Error, ErrorKind};

/// Covariance options for GMMs.
///
Expand Down Expand Up @@ -68,7 +69,7 @@ pub struct GaussianMixtureModel {

impl UnSupModel<Matrix<f64>, Matrix<f64>> for GaussianMixtureModel {
/// Train the model using inputs.
fn train(&mut self, inputs: &Matrix<f64>) {
fn train(&mut self, inputs: &Matrix<f64>) -> LearningResult<()> {
// Initialization:
let k = self.comp_count;

Expand Down Expand Up @@ -98,14 +99,16 @@ impl UnSupModel<Matrix<f64>, Matrix<f64>> for GaussianMixtureModel {

self.update_params(inputs, weights);
}

Ok(())
}

/// Predict output from inputs.
fn predict(&self, inputs: &Matrix<f64>) -> Matrix<f64> {
fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
if let (&Some(_), &Some(_)) = (&self.model_means, &self.model_covars) {
self.membership_weights(inputs).0
Ok(self.membership_weights(inputs).0)
} else {
panic!("Model has not been trained.");
Err(Error::new_untrained())
}

}
Expand Down Expand Up @@ -148,32 +151,33 @@ impl GaussianMixtureModel {
///
/// let mix_weights = Vector::new(vec![0.25, 0.25, 0.5]);
///
/// let _ = GaussianMixtureModel::with_weights(3, mix_weights);
/// let gmm = GaussianMixtureModel::with_weights(3, mix_weights).unwrap();
/// ```
///
/// # Panics
/// # Failures
///
/// Panics if either of the following conditions are met:
/// Fails if either of the following conditions are met:
///
/// - Mixture weights do not have length k.
/// - Mixture weights have a negative entry.
pub fn with_weights(k: usize, mixture_weights: Vector<f64>) -> GaussianMixtureModel {
assert!(mixture_weights.size() == k,
"Mixture weights must have length k.");
assert!(!mixture_weights.data().iter().any(|&x| x < 0f64),
"Mixture weights must have only non-negative entries.");

let sum = mixture_weights.sum();
let normalized_weights = mixture_weights / sum;

GaussianMixtureModel {
comp_count: k,
mix_weights: normalized_weights,
model_means: None,
model_covars: None,
log_lik: 0f64,
max_iters: 100,
cov_option: CovOption::Full,
pub fn with_weights(k: usize, mixture_weights: Vector<f64>) -> LearningResult<GaussianMixtureModel> {
if mixture_weights.size() != k {
Err(Error::new(ErrorKind::InvalidParameters, "Mixture weights must have length k."))
} else if mixture_weights.data().iter().any(|&x| x < 0f64) {
Err(Error::new(ErrorKind::InvalidParameters, "Mixture weights must have only non-negative entries."))
} else {
let sum = mixture_weights.sum();
let normalized_weights = mixture_weights / sum;

Ok(GaussianMixtureModel {
comp_count: k,
mix_weights: normalized_weights,
model_means: None,
model_covars: None,
log_lik: 0f64,
max_iters: 100,
cov_option: CovOption::Full,
})
}
}

Expand Down Expand Up @@ -292,7 +296,7 @@ impl GaussianMixtureModel {

for i in 0..n {
let inputs_i = MatrixSlice::from_matrix(inputs, [i, 0], 1, d);
let diff = inputs_i - new_means_k;
let diff = inputs_i - new_means_k;
cov_mat += self.compute_cov(diff, membership_weights[[i, k]]);
}
new_covs.push(cov_mat / sum_weights[k]);
Expand Down Expand Up @@ -332,16 +336,16 @@ mod tests {
}

#[test]
#[should_panic]
fn test_negative_mixtures() {
let mix_weights = Vector::new(vec![-0.25, 0.75, 0.5]);
let _ = GaussianMixtureModel::with_weights(3, mix_weights);
let gmm_res = GaussianMixtureModel::with_weights(3, mix_weights);
assert!(gmm_res.is_err());
}

#[test]
#[should_panic]
fn test_wrong_length_mixtures() {
let mix_weights = Vector::new(vec![0.1, 0.25, 0.75, 0.5]);
let _ = GaussianMixtureModel::with_weights(3, mix_weights);
let gmm_res = GaussianMixtureModel::with_weights(3, mix_weights);
assert!(gmm_res.is_err());
}
}
Loading