diff --git a/.travis.yml b/.travis.yml index be969131..100255b2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,3 +11,5 @@ script: - cargo test --verbose - cargo build --features stats - cargo test --features stats + - cargo build --features datasets + - cargo test --features datasets diff --git a/Cargo.toml b/Cargo.toml index 256cc3cd..8e6adaa3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ license = "MIT" [features] stats = [] +datasets = [] [dependencies] num = { version = "0.1.35", default-features = false } diff --git a/src/datasets/iris.rs b/src/datasets/iris.rs new file mode 100644 index 00000000..7fa4d91c --- /dev/null +++ b/src/datasets/iris.rs @@ -0,0 +1,191 @@ +use rulinalg::matrix::Matrix; +use rulinalg::vector::Vector; + +use super::Dataset; + +/// Load iris dataset. +/// +/// The data set contains 3 classes of 50 instances each, where each class refers to a type of iris plant. +/// +/// ## Attribute Information +/// +/// ### Data +/// +/// ``Matrix`` contains following columns. +/// +/// - sepal length in cm +/// - sepal width in cm +/// - petal length in cm +/// - petal width in cm +/// +/// ### Target +/// +/// ``Vector`` contains numbers corresponding to iris species: +/// +/// - ``0``: Iris Setosa +/// - ``1``: Iris Versicolour +/// - ``2``: Iris Virginica +/// +/// Lichman, M. (2013). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. +/// Irvine, CA: University of California, School of Information and Computer Science. +pub fn load() -> Dataset, Vector> { + let data: Matrix = matrix![5.1, 3.5, 1.4, 0.2; + 4.9, 3.0, 1.4, 0.2; + 4.7, 3.2, 1.3, 0.2; + 4.6, 3.1, 1.5, 0.2; + 5.0, 3.6, 1.4, 0.2; + 5.4, 3.9, 1.7, 0.4; + 4.6, 3.4, 1.4, 0.3; + 5.0, 3.4, 1.5, 0.2; + 4.4, 2.9, 1.4, 0.2; + 4.9, 3.1, 1.5, 0.1; + 5.4, 3.7, 1.5, 0.2; + 4.8, 3.4, 1.6, 0.2; + 4.8, 3.0, 1.4, 0.1; + 4.3, 3.0, 1.1, 0.1; + 5.8, 4.0, 1.2, 0.2; + 5.7, 4.4, 1.5, 0.4; + 5.4, 3.9, 1.3, 0.4; + 5.1, 3.5, 1.4, 0.3; + 5.7, 3.8, 1.7, 0.3; + 5.1, 3.8, 1.5, 0.3; + 5.4, 3.4, 1.7, 0.2; + 5.1, 3.7, 1.5, 0.4; + 4.6, 3.6, 1.0, 0.2; + 5.1, 3.3, 1.7, 0.5; + 4.8, 3.4, 1.9, 0.2; + 5.0, 3.0, 1.6, 0.2; + 5.0, 3.4, 1.6, 0.4; + 5.2, 3.5, 1.5, 0.2; + 5.2, 3.4, 1.4, 0.2; + 4.7, 3.2, 1.6, 0.2; + 4.8, 3.1, 1.6, 0.2; + 5.4, 3.4, 1.5, 0.4; + 5.2, 4.1, 1.5, 0.1; + 5.5, 4.2, 1.4, 0.2; + 4.9, 3.1, 1.5, 0.1; + 5.0, 3.2, 1.2, 0.2; + 5.5, 3.5, 1.3, 0.2; + 4.9, 3.1, 1.5, 0.1; + 4.4, 3.0, 1.3, 0.2; + 5.1, 3.4, 1.5, 0.2; + 5.0, 3.5, 1.3, 0.3; + 4.5, 2.3, 1.3, 0.3; + 4.4, 3.2, 1.3, 0.2; + 5.0, 3.5, 1.6, 0.6; + 5.1, 3.8, 1.9, 0.4; + 4.8, 3.0, 1.4, 0.3; + 5.1, 3.8, 1.6, 0.2; + 4.6, 3.2, 1.4, 0.2; + 5.3, 3.7, 1.5, 0.2; + 5.0, 3.3, 1.4, 0.2; + 7.0, 3.2, 4.7, 1.4; + 6.4, 3.2, 4.5, 1.5; + 6.9, 3.1, 4.9, 1.5; + 5.5, 2.3, 4.0, 1.3; + 6.5, 2.8, 4.6, 1.5; + 5.7, 2.8, 4.5, 1.3; + 6.3, 3.3, 4.7, 1.6; + 4.9, 2.4, 3.3, 1.0; + 6.6, 2.9, 4.6, 1.3; + 5.2, 2.7, 3.9, 1.4; + 5.0, 2.0, 3.5, 1.0; + 5.9, 3.0, 4.2, 1.5; + 6.0, 2.2, 4.0, 1.0; + 6.1, 2.9, 4.7, 1.4; + 5.6, 2.9, 3.6, 1.3; + 6.7, 3.1, 4.4, 1.4; + 5.6, 3.0, 4.5, 1.5; + 5.8, 2.7, 4.1, 1.0; + 6.2, 2.2, 4.5, 1.5; + 5.6, 2.5, 3.9, 1.1; + 5.9, 3.2, 4.8, 1.8; + 6.1, 2.8, 4.0, 1.3; + 6.3, 2.5, 4.9, 1.5; + 6.1, 2.8, 4.7, 1.2; + 6.4, 2.9, 4.3, 1.3; + 6.6, 3.0, 4.4, 1.4; + 6.8, 2.8, 4.8, 1.4; + 6.7, 3.0, 5.0, 1.7; + 6.0, 2.9, 4.5, 1.5; + 5.7, 2.6, 3.5, 1.0; + 5.5, 2.4, 3.8, 1.1; + 5.5, 2.4, 3.7, 1.0; + 5.8, 2.7, 3.9, 1.2; + 6.0, 2.7, 5.1, 1.6; + 5.4, 3.0, 4.5, 1.5; + 6.0, 3.4, 4.5, 1.6; + 6.7, 3.1, 4.7, 1.5; + 6.3, 2.3, 4.4, 1.3; + 5.6, 3.0, 4.1, 1.3; + 5.5, 2.5, 4.0, 1.3; + 5.5, 2.6, 4.4, 1.2; + 6.1, 3.0, 4.6, 1.4; + 5.8, 2.6, 4.0, 1.2; + 5.0, 2.3, 3.3, 1.0; + 5.6, 2.7, 4.2, 1.3; + 5.7, 3.0, 4.2, 1.2; + 5.7, 2.9, 4.2, 1.3; + 6.2, 2.9, 4.3, 1.3; + 5.1, 2.5, 3.0, 1.1; + 5.7, 2.8, 4.1, 1.3; + 6.3, 3.3, 6.0, 2.5; + 5.8, 2.7, 5.1, 1.9; + 7.1, 3.0, 5.9, 2.1; + 6.3, 2.9, 5.6, 1.8; + 6.5, 3.0, 5.8, 2.2; + 7.6, 3.0, 6.6, 2.1; + 4.9, 2.5, 4.5, 1.7; + 7.3, 2.9, 6.3, 1.8; + 6.7, 2.5, 5.8, 1.8; + 7.2, 3.6, 6.1, 2.5; + 6.5, 3.2, 5.1, 2.0; + 6.4, 2.7, 5.3, 1.9; + 6.8, 3.0, 5.5, 2.1; + 5.7, 2.5, 5.0, 2.0; + 5.8, 2.8, 5.1, 2.4; + 6.4, 3.2, 5.3, 2.3; + 6.5, 3.0, 5.5, 1.8; + 7.7, 3.8, 6.7, 2.2; + 7.7, 2.6, 6.9, 2.3; + 6.0, 2.2, 5.0, 1.5; + 6.9, 3.2, 5.7, 2.3; + 5.6, 2.8, 4.9, 2.0; + 7.7, 2.8, 6.7, 2.0; + 6.3, 2.7, 4.9, 1.8; + 6.7, 3.3, 5.7, 2.1; + 7.2, 3.2, 6.0, 1.8; + 6.2, 2.8, 4.8, 1.8; + 6.1, 3.0, 4.9, 1.8; + 6.4, 2.8, 5.6, 2.1; + 7.2, 3.0, 5.8, 1.6; + 7.4, 2.8, 6.1, 1.9; + 7.9, 3.8, 6.4, 2.0; + 6.4, 2.8, 5.6, 2.2; + 6.3, 2.8, 5.1, 1.5; + 6.1, 2.6, 5.6, 1.4; + 7.7, 3.0, 6.1, 2.3; + 6.3, 3.4, 5.6, 2.4; + 6.4, 3.1, 5.5, 1.8; + 6.0, 3.0, 4.8, 1.8; + 6.9, 3.1, 5.4, 2.1; + 6.7, 3.1, 5.6, 2.4; + 6.9, 3.1, 5.1, 2.3; + 5.8, 2.7, 5.1, 1.9; + 6.8, 3.2, 5.9, 2.3; + 6.7, 3.3, 5.7, 2.5; + 6.7, 3.0, 5.2, 2.3; + 6.3, 2.5, 5.0, 1.9; + 6.5, 3.0, 5.2, 2.0; + 6.2, 3.4, 5.4, 2.3; + 5.9, 3.0, 5.1, 1.8]; + let target: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]; + + Dataset{ data: data, + target: Vector::new(target) } +} \ No newline at end of file diff --git a/src/datasets/mod.rs b/src/datasets/mod.rs new file mode 100644 index 00000000..701b5514 --- /dev/null +++ b/src/datasets/mod.rs @@ -0,0 +1,25 @@ +use std::fmt::Debug; + +/// Module for iris dataset. +pub mod iris; + +/// Dataset container +#[derive(Clone, Debug)] +pub struct Dataset where D: Clone + Debug, T: Clone + Debug { + + data: D, + target: T +} + +impl Dataset where D: Clone + Debug, T: Clone + Debug { + + /// Returns explanatory variable (features) + pub fn data(&self) -> &D { + &self.data + } + + /// Returns objective variable (target) + pub fn target(&self) -> &T { + &self.target + } +} diff --git a/src/lib.rs b/src/lib.rs index 6f9da5e3..a822f58a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -221,3 +221,7 @@ pub mod analysis { pub mod cross_validation; pub mod score; } + +#[cfg(feature = "datasets")] +/// Module for datasets. +pub mod datasets; diff --git a/tests/datasets.rs b/tests/datasets.rs new file mode 100644 index 00000000..8fa0d93d --- /dev/null +++ b/tests/datasets.rs @@ -0,0 +1,18 @@ +extern crate rusty_machine as rm; + + +#[cfg(datasets)] +mod test { + + use rm::datasets::iris; + use rm::linalg::BaseMatrix; + + #[test] + fn test_iris() { + let dt = iris::load_(); + assert_eq!(dt.data().rows(), 150); + assert_eq!(dt.data().cols(), 4); + + assert_eq!(dt.target().size(), 150); + } +} \ No newline at end of file diff --git a/tests/lib.rs b/tests/lib.rs index f1261809..68abcdbf 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -10,4 +10,7 @@ pub mod learning { pub mod optim { mod grad_desc; } -} \ No newline at end of file +} + +#[cfg(datasets)] +pub mod datasets; \ No newline at end of file