diff --git a/src/data/transforms/minmax.rs b/src/data/transforms/minmax.rs index 81faa186..17c38c64 100644 --- a/src/data/transforms/minmax.rs +++ b/src/data/transforms/minmax.rs @@ -26,7 +26,7 @@ use learning::error::{Error, ErrorKind}; use linalg::{Matrix, BaseMatrix, BaseMatrixMut}; -use super::Transformer; +use super::{Invertible, Transformer}; use rulinalg::utils; @@ -145,7 +145,9 @@ impl Transformer> for MinMaxScaler { Ok(inputs) } +} +impl Invertible> for MinMaxScaler { fn inv_transform(&self, mut inputs: Matrix) -> Result, Error> { if let (&Some(ref scales), &Some(ref consts)) = (&self.scale_factors, &self.const_factors) { @@ -171,7 +173,7 @@ impl Transformer> for MinMaxScaler { #[cfg(test)] mod tests { use super::*; - use super::super::Transformer; + use super::super::{Transformer, Invertible}; use linalg::Matrix; use std::f64; diff --git a/src/data/transforms/mod.rs b/src/data/transforms/mod.rs index 872b3863..039aba7e 100644 --- a/src/data/transforms/mod.rs +++ b/src/data/transforms/mod.rs @@ -1,6 +1,6 @@ //! The Transforms module //! -//! This module contains the `Transformer` trait and reexports +//! This module contains the `Transformer` and `Invertible` traits and reexports //! the transformers from child modules. //! //! The `Transformer` trait provides a shared interface for all of the @@ -11,17 +11,22 @@ pub mod minmax; pub mod standardize; +pub mod shuffle; use learning::error; pub use self::minmax::MinMaxScaler; +pub use self::shuffle::Shuffler; pub use self::standardize::Standardizer; /// Trait for data transformers pub trait Transformer { /// Transforms the inputs and stores the transformation in the Transformer fn transform(&mut self, inputs: T) -> Result; +} +/// Trait for invertible data transformers +pub trait Invertible : Transformer { /// Maps the inputs using the inverse of the fitted transform. fn inv_transform(&self, inputs: T) -> Result; } \ No newline at end of file diff --git a/src/data/transforms/shuffle.rs b/src/data/transforms/shuffle.rs new file mode 100644 index 00000000..b20fc0bf --- /dev/null +++ b/src/data/transforms/shuffle.rs @@ -0,0 +1,120 @@ +//! The Shuffler +//! +//! This module contains the `Shuffler` transformer. `Shuffler` implements the +//! `Transformer` trait and is used to shuffle the rows of an input matrix. +//! You can control the random number generator used by the `Shuffler`. +//! +//! # Examples +//! +//! ``` +//! use rusty_machine::linalg::Matrix; +//! use rusty_machine::data::transforms::Transformer; +//! use rusty_machine::data::transforms::shuffle::Shuffler; +//! +//! // Create an input matrix that we want to shuffle +//! let mat = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); +//! +//! // Create a new shuffler +//! let mut shuffler = Shuffler::default(); +//! let shuffled_mat = shuffler.transform(mat).unwrap(); +//! +//! println!("{}", shuffled_mat); +//! ``` + +use learning::LearningResult; +use linalg::{Matrix, BaseMatrix, BaseMatrixMut}; +use super::Transformer; + +use rand::{Rng, thread_rng, ThreadRng}; + +/// The `Shuffler` +/// +/// Provides an implementation of `Transformer` which shuffles +/// the input rows in place. +#[derive(Debug)] +pub struct Shuffler { + rng: R, +} + +impl Shuffler { + /// Construct a new `Shuffler` with given random number generator. + /// + /// # Examples + /// + /// ``` + /// # extern crate rand; + /// # extern crate rusty_machine; + /// + /// use rusty_machine::data::transforms::Transformer; + /// use rusty_machine::data::transforms::shuffle::Shuffler; + /// use rand::{StdRng, SeedableRng}; + /// + /// # fn main() { + /// // We can create a seeded rng + /// let rng = StdRng::from_seed(&[1, 2, 3]); + /// + /// let shuffler = Shuffler::new(rng); + /// # } + /// ``` + pub fn new(rng: R) -> Self { + Shuffler { rng: rng } + } +} + +/// Create a new shuffler using the `rand::thread_rng` function +/// to provide a randomly seeded random number generator. +impl Default for Shuffler { + fn default() -> Self { + Shuffler { rng: thread_rng() } + } +} + +/// The `Shuffler` will transform the input `Matrix` by shuffling +/// its rows in place. +/// +/// Under the hood this uses a Fisher-Yates shuffle. +impl Transformer> for Shuffler { + fn transform(&mut self, mut inputs: Matrix) -> LearningResult> { + let n = inputs.rows(); + + for i in 0..n { + // Swap i with a random point after it + let j = self.rng.gen_range(0, n - i); + inputs.swap_rows(i, i + j); + } + + Ok(inputs) + } +} + +#[cfg(test)] +mod tests { + use linalg::Matrix; + use super::super::Transformer; + use super::Shuffler; + + use rand::{StdRng, SeedableRng}; + + #[test] + fn seeded_shuffle() { + let rng = StdRng::from_seed(&[1, 2, 3]); + let mut shuffler = Shuffler::new(rng); + + let mat = Matrix::new(4, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); + let shuffled = shuffler.transform(mat).unwrap(); + + assert_eq!(shuffled.into_vec(), + vec![3.0, 4.0, 1.0, 2.0, 7.0, 8.0, 5.0, 6.0]); + } + + #[test] + fn shuffle_single_row() { + let mut shuffler = Shuffler::default(); + + let mat = Matrix::new(1, 8, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); + let shuffled = shuffler.transform(mat).unwrap(); + + assert_eq!(shuffled.into_vec(), + vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); + } +} \ No newline at end of file diff --git a/src/data/transforms/standardize.rs b/src/data/transforms/standardize.rs index 141c1f23..a790ef49 100644 --- a/src/data/transforms/standardize.rs +++ b/src/data/transforms/standardize.rs @@ -26,7 +26,7 @@ use learning::error::{Error, ErrorKind}; use linalg::{Matrix, Vector, Axes, BaseMatrix, BaseMatrixMut}; -use super::Transformer; +use super::{Invertible, Transformer}; use rulinalg::utils; @@ -114,7 +114,9 @@ impl Transformer> for Standardizer { Ok(inputs) } } +} +impl Invertible> for Standardizer { fn inv_transform(&self, mut inputs: Matrix) -> Result, Error> { if let (&Some(ref means), &Some(ref variances)) = (&self.means, &self.variances) { @@ -143,7 +145,7 @@ impl Transformer> for Standardizer { #[cfg(test)] mod tests { use super::*; - use super::super::Transformer; + use super::super::{Transformer, Invertible}; use linalg::{Axes, Matrix}; use std::f64;