diff --git a/src/data/transforms/minmax.rs b/src/data/transforms/minmax.rs
index 17c38c64..580d0ea3 100644
--- a/src/data/transforms/minmax.rs
+++ b/src/data/transforms/minmax.rs
@@ -25,7 +25,7 @@
 //! ```
 
 use learning::error::{Error, ErrorKind};
-use linalg::{Matrix, BaseMatrix, BaseMatrixMut};
+use linalg::{Matrix, BaseMatrix, BaseMatrixMut, Vector};
 use super::{Invertible, Transformer};
 use rulinalg::utils;
 
@@ -42,9 +42,9 @@ use libnum::Float;
 #[derive(Debug)]
 pub struct MinMaxScaler<T: Float> {
     /// Values to scale each column by
-    scale_factors: Option<Vec<T>>,
+    scale_factors: Option<Vector<T>>,
     /// Values to add to each column after scaling
-    const_factors: Option<Vec<T>>,
+    const_factors: Option<Vector<T>>,
     /// The min of the new data (default 0)
     scaled_min: T,
     /// The max of the new data (default 1)
@@ -82,9 +82,12 @@ impl<T: Float> MinMaxScaler<T> {
 }
 
 impl<T: Float> Transformer<Matrix<T>> for MinMaxScaler<T> {
-    fn transform(&mut self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
+
+    fn fit(&mut self, inputs: &Matrix<T>) -> Result<(), Error> {
         let features = inputs.cols();
 
+        // ToDo: can use min, max
+        // https://github.com/AtheMathmo/rulinalg/pull/115
         let mut input_min_max = vec![(T::max_value(), T::min_value()); features];
 
         for row in inputs.iter_rows() {
@@ -95,18 +98,14 @@ impl<T: Float> Transformer<Matrix<T>> for MinMaxScaler<T> {
                                                      processed",
                                                     idx)));
                 }
 
-                // Update min
                 if *feature < min_max.0 {
                     min_max.0 = *feature;
                 }
 
-                // Update max
                 if *feature > min_max.1 {
                     min_max.1 = *feature;
                 }
-
-
             }
         }
 
@@ -130,28 +129,46 @@ impl<T: Float> Transformer<Matrix<T>> for MinMaxScaler<T> {
             .map(|(&(_, x), &s)| self.scaled_max - x * s)
             .collect::<Vec<T>>();
 
-        for row in inputs.iter_rows_mut() {
-            utils::in_place_vec_bin_op(row, &scales, |x, &y| {
-                *x = *x * y;
-            });
+        self.scale_factors = Some(Vector::new(scales));
+        self.const_factors = Some(Vector::new(consts));
+        Ok(())
+    }
 
-            utils::in_place_vec_bin_op(row, &consts, |x, &y| {
-                *x = *x + y;
-            });
+    fn transform(&mut self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
+        if let (&None, &None) = (&self.scale_factors, &self.const_factors) {
+            // if Transformer is not fitted to the data, fit for backward-compat.
+            try!(self.fit(&inputs));
         }
 
-        self.scale_factors = Some(scales);
-        self.const_factors = Some(consts);
-
-        Ok(inputs)
+        if let (&Some(ref scales), &Some(ref consts)) = (&self.scale_factors, &self.const_factors) {
+            if scales.size() != inputs.cols() {
+                Err(Error::new(ErrorKind::InvalidData,
+                               "Input data has different number of columns from fitted data."))
+            } else {
+                for row in inputs.iter_rows_mut() {
+                    utils::in_place_vec_bin_op(row, scales.data(), |x, &y| {
+                        *x = *x * y;
+                    });
+
+                    utils::in_place_vec_bin_op(row, consts.data(), |x, &y| {
+                        *x = *x + y;
+                    });
+                }
+                Ok(inputs)
+            }
+        } else {
+            // can't happen
+            Err(Error::new(ErrorKind::InvalidState, "Transformer has not been fitted."))
+        }
     }
 }
 
 impl<T: Float> Invertible<Matrix<T>> for MinMaxScaler<T> {
+
     fn inv_transform(&self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
         if let (&Some(ref scales), &Some(ref consts)) = (&self.scale_factors,
                                                          &self.const_factors) {
-            let features = scales.len();
+            let features = scales.size();
             if inputs.cols() != features {
                 return Err(Error::new(ErrorKind::InvalidData,
                                       "Inputs have different feature count than transformer."));
diff --git a/src/data/transforms/mod.rs b/src/data/transforms/mod.rs
index 039aba7e..e27dfc2d 100644
--- a/src/data/transforms/mod.rs
+++ b/src/data/transforms/mod.rs
@@ -6,7 +6,7 @@
 //! The `Transformer` trait provides a shared interface for all of the
 //! data preprocessing transformations in rusty-machine.
 //!
-//! The transformers provide preprocessing transformations which are 
+//! The transformers provide preprocessing transformations which are
 //! commonly used in machine learning.
 
 pub mod minmax;
@@ -21,6 +21,8 @@ pub use self::standardize::Standardizer;
 
 /// Trait for data transformers
 pub trait Transformer<T> {
+    /// Fit Transformer to input data, and stores the transformation in the Transformer
+    fn fit(&mut self, inputs: &T) -> Result<(), error::Error>;
     /// Transforms the inputs and stores the transformation in the Transformer
     fn transform(&mut self, inputs: T) -> Result<T, error::Error>;
 }
diff --git a/src/data/transforms/shuffle.rs b/src/data/transforms/shuffle.rs
index b20fc0bf..81112ef4 100644
--- a/src/data/transforms/shuffle.rs
+++ b/src/data/transforms/shuffle.rs
@@ -22,6 +22,7 @@
 //! ```
 
 use learning::LearningResult;
+use learning::error::Error;
 use linalg::{Matrix, BaseMatrix, BaseMatrixMut};
 use super::Transformer;
 
@@ -74,6 +75,12 @@ impl Default for Shuffler<StdRng> {
 ///
 /// Under the hood this uses a Fisher-Yates shuffle.
 impl<R: Rng, T: Copy> Transformer<Matrix<T>> for Shuffler<R> {
+
+    #[allow(unused_variables)]
+    fn fit(&mut self, inputs: &Matrix<T>) -> Result<(), Error> {
+        Ok(())
+    }
+
     fn transform(&mut self, mut inputs: Matrix<T>) -> LearningResult<Matrix<T>> {
         let n = inputs.rows();
 
@@ -117,4 +124,16 @@
 
         assert_eq!(shuffled.into_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
     }
+
+    #[test]
+    fn shuffle_fit() {
+        let rng = StdRng::from_seed(&[1, 2, 3]);
+        let mut shuffler = Shuffler::new(rng);
+
+        // no op
+        let mat = Matrix::new(4, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
+        let res = shuffler.fit(&mat).unwrap();
+
+        assert_eq!(res, ());
+    }
 }
\ No newline at end of file
diff --git a/src/data/transforms/standardize.rs b/src/data/transforms/standardize.rs
index a790ef49..b13fa807 100644
--- a/src/data/transforms/standardize.rs
+++ b/src/data/transforms/standardize.rs
@@ -87,7 +87,8 @@ impl<T: Float> Standardizer<T> {
 }
 
 impl<T: Float> Transformer<Matrix<T>> for Standardizer<T> {
-    fn transform(&mut self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
+
+    fn fit(&mut self, inputs: &Matrix<T>) -> Result<(), Error> {
         if inputs.rows() <= 1 {
             Err(Error::new(ErrorKind::InvalidData,
                            "Cannot standardize data with only one row."))
@@ -100,18 +101,34 @@ impl<T: Float> Transformer<Matrix<T>> for Standardizer<T> {
             if mean.data().iter().any(|x| !x.is_finite()) {
                 return Err(Error::new(ErrorKind::InvalidData, "Some data point is non-finite."));
             }
-
-            for row in inputs.iter_rows_mut() {
-                // Subtract the mean
-                utils::in_place_vec_bin_op(row, &mean.data(), |x, &y| *x = *x - y);
-                utils::in_place_vec_bin_op(row, &variance.data(), |x, &y| {
-                    *x = (*x * self.scaled_stdev / y.sqrt()) + self.scaled_mean
-                });
-            }
-
             self.means = Some(mean);
             self.variances = Some(variance);
-            Ok(inputs)
+            Ok(())
+        }
+    }
+
+    fn transform(&mut self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
+        if let (&None, &None) = (&self.means, &self.variances) {
+            // if Transformer is not fitted to the data, fit for backward-compat.
+            try!(self.fit(&inputs));
+        }
+
+        if let (&Some(ref means), &Some(ref variances)) = (&self.means, &self.variances) {
+            if means.size() != inputs.cols() {
+                Err(Error::new(ErrorKind::InvalidData,
+                               "Input data has different number of columns from fitted data."))
+            } else {
+                for row in inputs.iter_rows_mut() {
+                    // Subtract the mean
+                    utils::in_place_vec_bin_op(row, means.data(), |x, &y| *x = *x - y);
+                    utils::in_place_vec_bin_op(row, variances.data(), |x, &y| {
+                        *x = (*x * self.scaled_stdev / y.sqrt()) + self.scaled_mean
+                    });
+                }
+                Ok(inputs)
+            }
+        } else {
+            Err(Error::new(ErrorKind::InvalidState, "Transformer has not been fitted."))
         }
     }
 }