Skip to content
This repository was archived by the owner on Jul 16, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/data/transforms/minmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

use learning::error::{Error, ErrorKind};
use linalg::{Matrix, BaseMatrix, BaseMatrixMut};
use super::Transformer;
use super::{Invertible, Transformer};

use rulinalg::utils;

Expand Down Expand Up @@ -145,7 +145,9 @@ impl<T: Float> Transformer<Matrix<T>> for MinMaxScaler<T> {

Ok(inputs)
}
}

impl<T: Float> Invertible<Matrix<T>> for MinMaxScaler<T> {
fn inv_transform(&self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
if let (&Some(ref scales), &Some(ref consts)) = (&self.scale_factors, &self.const_factors) {

Expand All @@ -171,7 +173,7 @@ impl<T: Float> Transformer<Matrix<T>> for MinMaxScaler<T> {
#[cfg(test)]
mod tests {
use super::*;
use super::super::Transformer;
use super::super::{Transformer, Invertible};
use linalg::Matrix;
use std::f64;

Expand Down
7 changes: 6 additions & 1 deletion src/data/transforms/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! The Transforms module
//!
//! This module contains the `Transformer` trait and reexports
//! This module contains the `Transformer` and `Invertible` traits and reexports
//! the transformers from child modules.
//!
//! The `Transformer` trait provides a shared interface for all of the
Expand All @@ -11,17 +11,22 @@

pub mod minmax;
pub mod standardize;
pub mod shuffle;

use learning::error;

pub use self::minmax::MinMaxScaler;
pub use self::shuffle::Shuffler;
pub use self::standardize::Standardizer;

/// Trait for data transformers
///
/// The type parameter `T` is the type of the data being transformed; the
/// inputs are taken by value and the transformed data is returned by value.
pub trait Transformer<T> {
/// Transforms the inputs and stores the transformation in the Transformer
///
/// Takes `&mut self` so that an implementation is able to update its own
/// state while transforming. Returns the transformed data on success, or
/// an `error::Error` on failure.
fn transform(&mut self, inputs: T) -> Result<T, error::Error>;
}

/// Trait for invertible data transformers
///
/// `Transformer<T>` is a supertrait: every `Invertible<T>` must also
/// implement `Transformer<T>`.
pub trait Invertible<T> : Transformer<T> {
/// Maps the inputs using the inverse of the fitted transform.
///
/// Takes `&self`, so the transformer itself is not modified by this call.
/// Returns the mapped data on success, or an `error::Error` on failure.
fn inv_transform(&self, inputs: T) -> Result<T, error::Error>;
}
120 changes: 120 additions & 0 deletions src/data/transforms/shuffle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
//! The Shuffler
//!
//! This module contains the `Shuffler` transformer. `Shuffler` implements the
//! `Transformer` trait and is used to shuffle the rows of an input matrix.
//! You can control the random number generator used by the `Shuffler`.
//!
//! # Examples
//!
//! ```
//! use rusty_machine::linalg::Matrix;
//! use rusty_machine::data::transforms::Transformer;
//! use rusty_machine::data::transforms::shuffle::Shuffler;
//!
//! // Create an input matrix that we want to shuffle
//! let mat = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
//!
//! // Create a new shuffler
//! let mut shuffler = Shuffler::default();
//! let shuffled_mat = shuffler.transform(mat).unwrap();
//!
//! println!("{}", shuffled_mat);
//! ```

use learning::LearningResult;
use linalg::{Matrix, BaseMatrix, BaseMatrixMut};
use super::Transformer;

use rand::{Rng, thread_rng, ThreadRng};

/// The `Shuffler`
///
/// Provides an implementation of `Transformer` which shuffles
/// the input rows in place.
///
/// The type parameter `R` is the random number generator type
/// held by the shuffler.
#[derive(Debug)]
pub struct Shuffler<R: Rng> {
// Random number generator owned by this shuffler.
rng: R,
}

impl<R: Rng> Shuffler<R> {
/// Construct a new `Shuffler` with given random number generator.
///
/// The generator is moved into the shuffler and stored in it.
///
/// # Examples
///
/// ```
/// # extern crate rand;
/// # extern crate rusty_machine;
///
/// use rusty_machine::data::transforms::Transformer;
/// use rusty_machine::data::transforms::shuffle::Shuffler;
/// use rand::{StdRng, SeedableRng};
///
/// # fn main() {
/// // We can create a seeded rng
/// let rng = StdRng::from_seed(&[1, 2, 3]);
///
/// let shuffler = Shuffler::new(rng);
/// # }
/// ```
pub fn new(rng: R) -> Self {
Shuffler { rng: rng }
}
}

/// Builds a `Shuffler` backed by the generator returned from the
/// `rand::thread_rng` function, which provides a randomly seeded
/// random number generator.
impl Default for Shuffler<ThreadRng> {
    fn default() -> Self {
        let rng = thread_rng();
        Shuffler { rng: rng }
    }
}

/// Transforming a `Matrix` with the `Shuffler` permutes its rows in
/// place and hands the same matrix back.
///
/// The permutation is produced with a Fisher-Yates shuffle.
impl<R: Rng, T> Transformer<Matrix<T>> for Shuffler<R> {
    fn transform(&mut self, mut inputs: Matrix<T>) -> LearningResult<Matrix<T>> {
        let row_count = inputs.rows();

        for row in 0..row_count {
            // Pick a partner at a random offset from `row` among the rows
            // from `row` onwards, then exchange the two rows.
            let remaining = row_count - row;
            let offset = self.rng.gen_range(0, remaining);
            let partner = row + offset;
            inputs.swap_rows(row, partner);
        }

        Ok(inputs)
    }
}

// Unit tests for the `Shuffler` transformer.
#[cfg(test)]
mod tests {
use linalg::Matrix;
use super::super::Transformer;
use super::Shuffler;

use rand::{StdRng, SeedableRng};

// Shuffles a 4x2 matrix with a fixed-seed `StdRng` and checks the
// resulting row order against a hard-coded expectation.
#[test]
fn seeded_shuffle() {
let rng = StdRng::from_seed(&[1, 2, 3]);
let mut shuffler = Shuffler::new(rng);

let mat = Matrix::new(4, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
let shuffled = shuffler.transform(mat).unwrap();

// The expected data is the original rows in a permuted order; it is
// tied to this seed and to the order of the random draws.
assert_eq!(shuffled.into_vec(),
vec![3.0, 4.0, 1.0, 2.0, 7.0, 8.0, 5.0, 6.0]);
}

// Shuffles a matrix with a single row (1x8) using the default shuffler
// and checks that the data comes back in its original order.
#[test]
fn shuffle_single_row() {
let mut shuffler = Shuffler::default();

let mat = Matrix::new(1, 8, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
let shuffled = shuffler.transform(mat).unwrap();

assert_eq!(shuffled.into_vec(),
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
}
}
6 changes: 4 additions & 2 deletions src/data/transforms/standardize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

use learning::error::{Error, ErrorKind};
use linalg::{Matrix, Vector, Axes, BaseMatrix, BaseMatrixMut};
use super::Transformer;
use super::{Invertible, Transformer};

use rulinalg::utils;

Expand Down Expand Up @@ -114,7 +114,9 @@ impl<T: Float + FromPrimitive> Transformer<Matrix<T>> for Standardizer<T> {
Ok(inputs)
}
}
}

impl<T: Float + FromPrimitive> Invertible<Matrix<T>> for Standardizer<T> {
fn inv_transform(&self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
if let (&Some(ref means), &Some(ref variances)) = (&self.means, &self.variances) {

Expand Down Expand Up @@ -143,7 +145,7 @@ impl<T: Float + FromPrimitive> Transformer<Matrix<T>> for Standardizer<T> {
#[cfg(test)]
mod tests {
use super::*;
use super::super::Transformer;
use super::super::{Transformer, Invertible};
use linalg::{Axes, Matrix};

use std::f64;
Expand Down