From d2367df461af2f1148c6bbe0c3422469bceb9b8e Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sun, 12 May 2024 11:11:37 -0500 Subject: [PATCH 01/23] update Signed-off-by: Joe McCain III --- Cargo.toml | 2 +- core/Cargo.toml | 4 + core/src/macros.rs | 69 ++++++++++++---- core/src/nn/models/model.rs | 21 ++++- core/src/nn/models/module.rs | 10 ++- core/src/traits/generator.rs | 11 +++ core/src/traits/misc/setup.rs | 53 +++++++++++++ core/src/traits/mod.rs | 20 ++--- core/src/traits/ops.rs | 47 +++++++++++ models/linear/src/model/config.rs | 7 ++ models/transformers/src/attention/head.rs | 29 ++++--- models/transformers/src/codec/decoder.rs | 13 +++ models/transformers/src/codec/encoder.rs | 13 +++ models/transformers/src/codec/mod.rs | 25 ++++++ models/transformers/src/codec/model.rs | 68 ++++++++++++++++ models/transformers/src/impls/impl_head.rs | 4 +- models/transformers/src/impls/impl_linalg.rs | 16 ++-- models/transformers/src/impls/impl_params.rs | 12 +-- models/transformers/src/lib.rs | 4 +- models/transformers/src/macros.rs | 61 +++++++++----- models/transformers/src/params/item.rs | 79 +++++++++++++++++++ models/transformers/src/params/mod.rs | 19 +++-- .../src/params/{qkv.rs => store.rs} | 41 ++++++++-- models/transformers/tests/attention.rs | 6 +- 24 files changed, 533 insertions(+), 101 deletions(-) create mode 100644 core/src/traits/generator.rs create mode 100644 core/src/traits/ops.rs create mode 100644 models/transformers/src/codec/decoder.rs create mode 100644 models/transformers/src/codec/encoder.rs create mode 100644 models/transformers/src/codec/mod.rs create mode 100644 models/transformers/src/codec/model.rs create mode 100644 models/transformers/src/params/item.rs rename models/transformers/src/params/{qkv.rs => store.rs} (55%) diff --git a/Cargo.toml b/Cargo.toml index 87b42c15..0f3ef677 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ version = "0.1.14" [workspace.dependencies] # acme = { features = ["full"], branch = "v0.3.2", 
git = "https://github.com/FL03/acme", version = "0.3.2" } # ndtensor = { features = ["full"], branch = "v0.1.1", git = "https://github.com/FL03/ndtensor", version = "0.1" } -# scsys = { features = ["full"], branch = "v0.2.2", git = "https://github.com/scattered-systems/scsys", version = "0.2" } +scsys = { default-features = false, branch = "v0.2.3", git = "https://github.com/scattered-systems/scsys.git", version = "0.2" } approx = "0.5" itertools = "0.12" diff --git a/core/Cargo.toml b/core/Cargo.toml index a710e1d0..d112087b 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -30,6 +30,7 @@ alloc = [ "num/alloc", "rand?/alloc", "rand_distr?/alloc", + "scsys/alloc", "serde?/alloc", ] @@ -66,6 +67,7 @@ serde = [ "num/serde", "rand?/serde1", "rand_distr?/serde1", + "scsys/serde", "uuid/serde" ] @@ -83,6 +85,7 @@ std = [ "ndarray/std", "num/std", "rng_std", + "scsys/std", "serde/std", "strum/std", "uuid/std" @@ -111,6 +114,7 @@ required-features = ["approx"] [dependencies] ndarray.workspace = true num.workspace = true +scsys.workspace = true smart-default.workspace = true strum.workspace = true diff --git a/core/src/macros.rs b/core/src/macros.rs index b447975f..7498b0fd 100644 --- a/core/src/macros.rs +++ b/core/src/macros.rs @@ -99,18 +99,59 @@ macro_rules! build_unary_trait { }; } -macro_rules! linspace { - (start: $start:expr, end: $end:expr, n: $n:expr, dtype: $T:ty) => { - ndarray::Array1::<$T>::linspace($start, $end, $n) - }; - (end: $end:expr, dtype: $T:ty) => { - let n = ($end - $T::one()).to_usize().unwrap(); - ndarray::Array1::<$T>::linspace($T::zero(), $end, $end.to_usize().unwrap()) - }; - (dim: $dim:expr, dtype: $T:ty) => {{ - let dim = $dim.into_dimension(); - let n = dim.size(); - ndarray::Array1::<$T>::linspace(<$T>::zero(), <$T>::from(n - 1).unwrap(), n) - .into_shape($dim) - }}; +#[macro_export] +macro_rules! 
builder { + ($(#[derive($($d:ident),*)])?$name:ident::<$inner:ty> {$($k:ident: $v:ty),* $(,)?}) => { + $crate::builder!(@loop $(#[derive($($d),*)])?$name::<$inner> {$($k: $v),*}); + }; + (@loop #[derive($($d:ident),*)] $name:ident::<$inner:ty> {$($k:ident: $v:ty),* $(,)?}) => { + pub struct $name { + inner: $inner, + } + + impl $name { + pub fn new() -> Self { + Self { inner: Default::default() } + } + + pub fn build(self) -> $inner { + self.inner + } + + $( + pub fn $k(mut self, $k: $v) -> Self { + self.inner.$k = $k; + self + } + )* + } + }; + (@loop $name:ident::<$inner:ty> {$($k:ident: $v:ty),* $(,)?}) => { + pub struct $name { + inner: $inner, + } + + impl $name { + pub fn new() -> Self { + Self { + inner: Default::default() + } + } + + pub fn from_inner(inner: $inner) -> Self { + Self { inner } + } + + pub fn build(self) -> $inner { + self.inner + } + + $( + pub fn $k(mut self, $k: $v) -> Self { + self.inner.$k = $k; + self + } + )* + } + }; } diff --git a/core/src/nn/models/model.rs b/core/src/nn/models/model.rs index 16687d4b..00736aaf 100644 --- a/core/src/nn/models/model.rs +++ b/core/src/nn/models/model.rs @@ -1,6 +1,21 @@ /* - Appellation: model - Contrib: FL03 + Appellation: model + Contrib: FL03 */ +use super::{DynModule, Module}; +use crate::traits::Forward; -pub trait Model {} +pub trait Model: Module +where + Self: Forward, +{ + type Ctx; + type Data; + + fn children(&self) -> Vec>; +} + +pub struct ConfigBase { + pub id: usize, + pub name: String, +} diff --git a/core/src/nn/models/module.rs b/core/src/nn/models/module.rs index af182f25..79fbe8ed 100644 --- a/core/src/nn/models/module.rs +++ b/core/src/nn/models/module.rs @@ -2,13 +2,17 @@ Appellation: modules Contrib: FL03 */ -use crate::Predict; +use crate::{Config, Predict}; + +pub type DynModule = Box>; +pub type DynModuleExt = Box>; +pub type Stack = Vec>>; /// A `Module` defines any object that may be used as a layer in a neural network. 
/// [Config](Module::Config) is a type that defines the configuration of the module; including any and all hyperparameters. /// [Params](Module::Params) is a type that defines the parameters of the module; typically references a Linear set of parameters { weights, bias } pub trait Module { - type Config; + type Config: Config; type Params; fn config(&self) -> &Self::Config; @@ -20,4 +24,4 @@ pub trait Module { pub trait ModuleExt: Module + Predict {} -pub type Stack = Vec>>; +impl ModuleExt for M where M: Module + Predict {} diff --git a/core/src/traits/generator.rs b/core/src/traits/generator.rs new file mode 100644 index 00000000..54d058fd --- /dev/null +++ b/core/src/traits/generator.rs @@ -0,0 +1,11 @@ +/* + Appellation: generator + Contrib: FL03 +*/ + +/// This trait describes actors that can generate data +pub trait Generative { + type Output; + + fn generate(&self, args: T) -> Self::Output; +} diff --git a/core/src/traits/misc/setup.rs b/core/src/traits/misc/setup.rs index 1b7856f9..d6f2f611 100644 --- a/core/src/traits/misc/setup.rs +++ b/core/src/traits/misc/setup.rs @@ -2,6 +2,26 @@ Appellation: setup Contrib: FL03 */ +use core::borrow::{Borrow, BorrowMut}; + +/// A trait used to denote objects that may be used for configuring various items +pub trait Config {} + +/// [Configuration] describes composite configuration objects; +/// A configuration object is allowed to inherit from another configuration object +pub trait Configuration +where + C: Config, + Self::Config: Borrow, +{ + type Config: Config; + + fn root(&self) -> &C; + + fn set(&mut self, config: Self::Config); + + fn set_root(&mut self, config: C); +} pub trait Init { fn init(self) -> Self; @@ -16,3 +36,36 @@ pub trait Setup { fn setup(&mut self, config: Self::Config); } + +pub trait Context +where + C: Config, +{ + type Cnf: Configuration; + + fn config(&self) -> Self::Cnf; +} + +/* + ************* Implementations ************* +*/ + +impl Configuration for D +where + C: Config, + D: Config + 
BorrowMut, +{ + type Config = D; + + fn root(&self) -> &C { + self.borrow() + } + + fn set(&mut self, config: Self::Config) { + *self = config; + } + + fn set_root(&mut self, config: C) { + *self.borrow_mut() = config; + } +} diff --git a/core/src/traits/mod.rs b/core/src/traits/mod.rs index 8ee0a14a..6663f98b 100644 --- a/core/src/traits/mod.rs +++ b/core/src/traits/mod.rs @@ -4,7 +4,9 @@ */ pub use self::prelude::*; +pub mod generator; pub mod math; +pub mod ops; pub mod predict; pub mod train; @@ -39,22 +41,12 @@ pub(crate) mod misc { } } -pub trait Transform { - type Output; - - fn transform(&self, args: &T) -> Self::Output; -} - pub(crate) mod prelude { - pub use super::Transform; - + pub use super::arr::prelude::*; + pub use super::generator::*; pub use super::math::*; + pub use super::misc::prelude::*; + pub use super::ops::*; pub use super::predict::*; pub use super::train::*; - - pub use super::arr::prelude::*; - pub use super::misc::prelude::*; } - -#[cfg(test)] -mod tests {} diff --git a/core/src/traits/ops.rs b/core/src/traits/ops.rs new file mode 100644 index 00000000..68dce74b --- /dev/null +++ b/core/src/traits/ops.rs @@ -0,0 +1,47 @@ +/* + Appellation: ops + Contrib: FL03 +*/ +/// A trait for applying a function to a type +pub trait Apply { + type Output; + + fn apply(&self, f: F) -> Self::Output + where + F: Fn(T) -> U; + + fn apply_mut(&mut self, f: F) -> Self::Output + where + F: FnMut(T) -> U; +} + +pub trait ApplyOnce { + type Output; + + fn apply(self, f: F) -> Self::Output + where + F: FnMut(T) -> U; +} + +pub trait Transform { + type Output; + + fn transform(&self, args: &T) -> Self::Output; +} + +/* + ************* Implementations ************* +*/ +impl ApplyOnce for S +where + S: Iterator, +{ + type Output = core::iter::Map; + + fn apply(self, f: F) -> Self::Output + where + F: FnMut(T) -> U, + { + self.map(f) + } +} diff --git a/models/linear/src/model/config.rs b/models/linear/src/model/config.rs index 568cfc25..6f4f40d8 100644 --- 
a/models/linear/src/model/config.rs +++ b/models/linear/src/model/config.rs @@ -154,6 +154,13 @@ where } } +impl concision::Config for Config +where + D: Dimension, + K: ParamMode, +{ +} + impl Default for Config where D: Dimension, diff --git a/models/transformers/src/attention/head.rs b/models/transformers/src/attention/head.rs index c85523bb..f3bb9faa 100644 --- a/models/transformers/src/attention/head.rs +++ b/models/transformers/src/attention/head.rs @@ -2,7 +2,7 @@ Appellation: head Contrib: FL03 */ -use crate::params::QKVBase; +use crate::params::ParamsBase; use nd::*; pub struct AttentionHead, D = Ix2> @@ -10,7 +10,7 @@ where D: Dimension, S: RawData, { - params: QKVBase, + params: ParamsBase, } impl AttentionHead @@ -18,7 +18,7 @@ where D: Dimension, S: RawData, { - pub fn from_params(params: QKVBase) -> Self { + pub fn from_params(params: ParamsBase) -> Self { Self { params } } @@ -27,19 +27,28 @@ where F: Fn(D) -> ArrayBase, Sh: ShapeBuilder, { - Self::from_params(QKVBase::builder(shape, builder)) + Self::from_params(ParamsBase::builder(shape, builder)) } - pub fn params(&self) -> &QKVBase { + pub fn from_elem(shape: Sh, value: A) -> Self + where + Sh: ShapeBuilder, + A: Clone, + S: DataOwned, + { + Self::from_params(ParamsBase::from_elem(shape, value)) + } + /// Returns a reference to the underlying parameters. + pub fn params(&self) -> &ParamsBase { &self.params } - - pub fn params_mut(&mut self) -> &mut QKVBase { + /// Returns a mutable reference to the underlying parameters. 
+ pub fn params_mut(&mut self) -> &mut ParamsBase { &mut self.params } access!(params::); - fwd_builder!(new.default where A: Default, S: DataOwned); - fwd_builder!(ones.ones where A: Clone + num::One, S: DataOwned); - fwd_builder!(zeros.zeros where A: Clone + num::Zero, S: DataOwned); + ndbuilder!(new.default where A: Default, S: DataOwned); + ndbuilder!(ones where A: Clone + num::One, S: DataOwned); + ndbuilder!(zeros where A: Clone + num::Zero, S: DataOwned); } diff --git a/models/transformers/src/codec/decoder.rs b/models/transformers/src/codec/decoder.rs new file mode 100644 index 00000000..019e5e94 --- /dev/null +++ b/models/transformers/src/codec/decoder.rs @@ -0,0 +1,13 @@ +/* + Appellation: decoder + Contrib: FL03 +*/ + +#[derive(Default)] +pub struct Decoder {} + +impl Decoder { + pub fn new() -> Self { + Self {} + } +} diff --git a/models/transformers/src/codec/encoder.rs b/models/transformers/src/codec/encoder.rs new file mode 100644 index 00000000..ba63c02b --- /dev/null +++ b/models/transformers/src/codec/encoder.rs @@ -0,0 +1,13 @@ +/* + Appellation: encoder + Contrib: FL03 +*/ + +#[derive(Default)] +pub struct Encoder {} + +impl Encoder { + pub fn new() -> Self { + Self {} + } +} diff --git a/models/transformers/src/codec/mod.rs b/models/transformers/src/codec/mod.rs new file mode 100644 index 00000000..3a7e3f77 --- /dev/null +++ b/models/transformers/src/codec/mod.rs @@ -0,0 +1,25 @@ +/* + Appellation: codec + Contrib: FL03 +*/ +pub use self::{decoder::Decoder, encoder::Encoder, model::*}; + +pub(crate) mod model; + +pub mod decoder; +pub mod encoder; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_codec_builder() { + let ctx = Context::new() + .with_src("src".to_string()) + .with_tgt("tgt".to_string()); + let codec = Codec::new().ctx(ctx).build(); + assert_eq!(codec.context().src, "src"); + assert_eq!(codec.context().tgt, "tgt"); + } +} diff --git a/models/transformers/src/codec/model.rs b/models/transformers/src/codec/model.rs 
new file mode 100644 index 00000000..4c3cc180 --- /dev/null +++ b/models/transformers/src/codec/model.rs @@ -0,0 +1,68 @@ +/* + Appellation: codec + Contrib: FL03 +*/ +use super::{Decoder, Encoder}; + +#[derive(Default)] +pub struct Codec { + ctx: Context, + decoder: Decoder, + encoder: Encoder, +} + +impl Codec { + pub fn new() -> CodecBuilder { + CodecBuilder::new() + } + + pub fn context(&self) -> &Context { + &self.ctx + } + + pub fn decoder(&self) -> &Decoder { + &self.decoder + } + + pub fn encoder(&self) -> &Encoder { + &self.encoder + } +} + +concision::builder!( + #[derive(Default)] + CodecBuilder:: { + ctx: Context, + decoder: Decoder, + encoder: Encoder, + } +); + +#[derive(Default)] +pub struct Generator { + pub dmodel: usize, + pub vocab: Vec, +} + +#[derive(Default)] +pub struct Context { + pub src: String, // source embedding + pub tgt: String, // target embedding +} + +impl Context { + pub fn new() -> Self { + Self { + src: String::new(), + tgt: String::new(), + } + } + + pub fn with_src(self, src: String) -> Self { + Self { src, ..self } + } + + pub fn with_tgt(self, tgt: String) -> Self { + Self { tgt, ..self } + } +} diff --git a/models/transformers/src/impls/impl_head.rs b/models/transformers/src/impls/impl_head.rs index ee5eaf10..ec8e410f 100644 --- a/models/transformers/src/impls/impl_head.rs +++ b/models/transformers/src/impls/impl_head.rs @@ -3,7 +3,7 @@ Contrib: FL03 */ use crate::attention::AttentionHead; -use crate::params::QKVBase; +use crate::params::ParamsBase; use nd::prelude::*; use nd::DataOwned; @@ -14,6 +14,6 @@ where S: DataOwned, { fn default() -> Self { - Self::from_params(QKVBase::default()) + Self::from_params(ParamsBase::default()) } } diff --git a/models/transformers/src/impls/impl_linalg.rs b/models/transformers/src/impls/impl_linalg.rs index 4440ec34..a50ebe63 100644 --- a/models/transformers/src/impls/impl_linalg.rs +++ b/models/transformers/src/impls/impl_linalg.rs @@ -2,12 +2,12 @@ Appellation: impl_linalg Contrib: 
FL03 */ -use crate::params::{QKVBase, QKV}; +use crate::params::{Params, ParamsBase}; use concision::Matmul; use nd::linalg::Dot; use nd::*; -impl Matmul> for QKVBase +impl Matmul> for ParamsBase where A: LinalgScalar, D: Dimension, @@ -17,10 +17,10 @@ where T: Data, ArrayBase: Dot, Output = Array>, { - type Output = QKV; + type Output = Params; - fn matmul(&self, rhs: &QKVBase) -> Self::Output { - QKVBase { + fn matmul(&self, rhs: &ParamsBase) -> Self::Output { + ParamsBase { q: self.q().dot(rhs.q()), k: self.k().dot(rhs.k()), v: self.v().dot(rhs.v()), @@ -28,7 +28,7 @@ where } } -impl Matmul> for QKVBase +impl Matmul> for ParamsBase where A: LinalgScalar, D: Dimension, @@ -38,10 +38,10 @@ where T: Data, ArrayBase: Dot, Output = Array>, { - type Output = QKV; + type Output = Params; fn matmul(&self, rhs: &ArrayBase) -> Self::Output { - QKVBase { + ParamsBase { q: self.q().dot(rhs), k: self.k().dot(rhs), v: self.v().dot(rhs), diff --git a/models/transformers/src/impls/impl_params.rs b/models/transformers/src/impls/impl_params.rs index ad79e30d..47c9074a 100644 --- a/models/transformers/src/impls/impl_params.rs +++ b/models/transformers/src/impls/impl_params.rs @@ -2,11 +2,11 @@ Appellation: impl_params Contrib: FL03 */ -use crate::params::QKVBase; +use crate::params::ParamsBase; use nd::prelude::*; use nd::{Data, DataOwned, RawDataClone}; -impl Clone for QKVBase +impl Clone for ParamsBase where D: Dimension, S: RawDataClone, @@ -20,14 +20,14 @@ where } } -impl Copy for QKVBase +impl Copy for ParamsBase where D: Copy + Dimension, S: Copy + RawDataClone, { } -impl Default for QKVBase +impl Default for ParamsBase where D: Dimension, S: DataOwned, @@ -42,7 +42,7 @@ where } } -impl PartialEq for QKVBase +impl PartialEq for ParamsBase where A: PartialEq, D: Dimension, @@ -53,7 +53,7 @@ where } } -impl PartialEq> for QKVBase +impl PartialEq> for ParamsBase where A: PartialEq, B: PartialEq, diff --git a/models/transformers/src/lib.rs b/models/transformers/src/lib.rs index 
e39da023..af17dd30 100644 --- a/models/transformers/src/lib.rs +++ b/models/transformers/src/lib.rs @@ -17,7 +17,7 @@ extern crate concision_core as concision; extern crate ndarray as nd; pub use self::attention::AttentionHead; -pub use self::params::QKV; +pub use self::params::*; pub use self::transformer::Transformer; #[macro_use] @@ -25,6 +25,7 @@ pub(crate) mod macros; pub(crate) mod transformer; pub mod attention; +pub mod codec; pub mod params; pub(crate) mod impls { @@ -35,6 +36,5 @@ pub(crate) mod impls { pub mod prelude { pub use super::attention::prelude::*; - pub use super::params::prelude::*; pub use super::Transformer; } diff --git a/models/transformers/src/macros.rs b/models/transformers/src/macros.rs index d7b352d0..f4712afc 100644 --- a/models/transformers/src/macros.rs +++ b/models/transformers/src/macros.rs @@ -32,8 +32,20 @@ macro_rules! access { }; } -macro_rules! fwd_builder { - ($method:ident.$call:ident where $($rest:tt)*) => { +macro_rules! ndbuilder { + ($method:ident $($rest:tt)*) => { + ndbuilder!(@impl $method $($rest)*); + }; + (@impl $method:ident where $($rest:tt)*) => { + pub fn $method(shape: Sh) -> Self + where + Sh: ndarray::ShapeBuilder, + $($rest)* + { + Self::builder(shape, ArrayBase::$method) + } + }; + (@impl $method:ident.$call:ident where $($rest:tt)*) => { pub fn $method(shape: Sh) -> Self where Sh: ndarray::ShapeBuilder, @@ -52,7 +64,15 @@ macro_rules! param_views { param_views!(@impl $method.$call::$($rest)*); }; (@impl $method:ident.$call:ident::<$view:ident>(self) where $($rest:tt)*) => { - pub fn $method(self) -> QKVBase<$view, D> + pub fn $method(self) -> $crate::params::ParamsBase<$view, D> + where + $($rest)* + { + param_views!(@apply $call(self)) + } + }; + (@impl $method:ident.$call:ident::<$view:ident>(mut self) where $($rest:tt)*) => { + pub fn $method(mut self) -> $crate::params::ParamsBase<$view, D> where $($rest)* { @@ -60,39 +80,42 @@ macro_rules! 
param_views { } }; (@impl $method:ident.$call:ident::<$view:ident>(&self) where $($rest:tt)*) => { - pub fn $method(&self) -> QKVBase<$view, D> + pub fn $method(&self) -> $crate::params::ParamsBase<$view, D> where $($rest)* { param_views!(@apply $call(self)) } }; - (@impl $method:ident.$call:ident::<'a, $view:ident>(&self) where $($rest:tt)*) => { - pub fn $method(&self) -> QKVBase<$view<&'_ A>, D> + (@impl $method:ident.$call:ident::<$view:ident>(&mut self) where $($rest:tt)*) => { + pub fn $method(&mut self) -> $crate::params::ParamsBase<$view, D> where $($rest)* { param_views!(@apply $call(self)) } }; - (@apply $call:ident($self:expr)) => { - $crate::params::QKVBase { - q: $self.q.$call(), - k: $self.k.$call(), - v: $self.v.$call(), + (@impl $method:ident.$call:ident::<'a, $view:ident>(&self) where $($rest:tt)*) => { + pub fn $method(&self) -> $crate::params::ParamsBase<$view<&'_ A>, D> + where + $($rest)* + { + param_views!(@apply $call(self)) } }; -} - -macro_rules! qkv_builder { - - ($method:ident.$call:ident where $($rest:tt)*) => { - pub fn $method(shape: Sh) -> Self + (@impl $method:ident.$call:ident::<'a, $view:ident>(&mut self) where $($rest:tt)*) => { + pub fn $method(&mut self) -> $crate::params::ParamsBase<$view<&'_ mut A>, D> where - Sh: ndarray::ShapeBuilder, $($rest)* { - Self::builder(shape, ArrayBase::$call) + param_views!(@apply $call(self)) + } + }; + (@apply $call:ident($self:expr)) => { + $crate::params::ParamsBase { + q: $self.q.$call(), + k: $self.k.$call(), + v: $self.v.$call(), } }; } diff --git a/models/transformers/src/params/item.rs b/models/transformers/src/params/item.rs new file mode 100644 index 00000000..67528ab1 --- /dev/null +++ b/models/transformers/src/params/item.rs @@ -0,0 +1,79 @@ +/* + Appellation: kinds + Contrib: FL03 +*/ +use nd::{ArrayBase, Dimension, Ix2, OwnedRepr, RawData}; +use strum::{AsRefStr, EnumCount, EnumDiscriminants, EnumIs, VariantNames}; + +#[derive(AsRefStr, EnumCount, EnumDiscriminants, EnumIs, 
VariantNames)] +#[strum_discriminants( + derive( + AsRefStr, + EnumCount, + EnumIs, + Hash, + Ord, + PartialOrd, + VariantNames, + strum::Display, + strum::EnumString, + ), + name(QKV), + strum(serialize_all = "lowercase") +)] +#[cfg_attr( + feature = "serde", + strum_discriminants( + derive(serde::Deserialize, serde::Serialize), + serde(rename_all = "lowercase", untagged) + ) +)] +#[strum(serialize_all = "lowercase")] +pub enum Entry, D = Ix2> +where + D: Dimension, + S: RawData, +{ + Q(ArrayBase), + K(ArrayBase), + V(ArrayBase), +} + +impl Entry +where + D: Dimension, + S: RawData, +{ + pub fn from_q(q: ArrayBase) -> Self { + Self::Q(q) + } + + pub fn from_k(k: ArrayBase) -> Self { + Self::K(k) + } + + pub fn from_v(v: ArrayBase) -> Self { + Self::V(v) + } + + pub fn q(&self) -> Option<&ArrayBase> { + match self { + Self::Q(q) => Some(q), + _ => None, + } + } + + pub fn k(&self) -> Option<&ArrayBase> { + match self { + Self::K(k) => Some(k), + _ => None, + } + } + + pub fn v(&self) -> Option<&ArrayBase> { + match self { + Self::V(v) => Some(v), + _ => None, + } + } +} diff --git a/models/transformers/src/params/mod.rs b/models/transformers/src/params/mod.rs index b4baf401..40cf9cc1 100644 --- a/models/transformers/src/params/mod.rs +++ b/models/transformers/src/params/mod.rs @@ -2,9 +2,10 @@ Appellation: params Contrib: FL03 */ -pub use self::qkv::QKVBase; +pub use self::{item::*, store::ParamsBase}; -mod qkv; +pub(crate) mod item; +pub(crate) mod store; macro_rules! params_ty { ($target:ident: [$($name:ident<$(&$lt:lifetime)?$repr:ident>),* $(,)?]) => { @@ -19,14 +20,16 @@ macro_rules! 
params_ty { } params_ty!( - QKVBase: [ - QKV, - ArcQKV, - ViewQKV<&'a ViewRepr>, + ParamsBase: [ + Params, + ArcParams, + ParamsView<&'a ViewRepr>, ] ); +#[allow(unused_imports)] pub(crate) mod prelude { - pub use super::QKVBase; - pub use super::{ArcQKV, QKV}; + pub use super::item::{Entry, QKV}; + pub use super::store::ParamsBase; + pub use super::{ArcParams, Params}; } diff --git a/models/transformers/src/params/qkv.rs b/models/transformers/src/params/store.rs similarity index 55% rename from models/transformers/src/params/qkv.rs rename to models/transformers/src/params/store.rs index 231f48d9..6fcdca1b 100644 --- a/models/transformers/src/params/qkv.rs +++ b/models/transformers/src/params/store.rs @@ -6,7 +6,7 @@ use nd::*; use num::traits::{One, Zero}; -pub struct QKVBase, D = Ix2> +pub struct ParamsBase, D = Ix2> where D: Dimension, S: RawData, @@ -16,7 +16,7 @@ where pub(crate) v: ArrayBase, } -impl QKVBase +impl ParamsBase where D: Dimension, S: RawData, @@ -34,19 +34,40 @@ where } } + ndbuilder!(new.default where A: Default, S: DataOwned); + ndbuilder!(ones where A: Clone + One, S: DataOwned); + ndbuilder!(zeros where A: Clone + Zero, S: DataOwned); + access!(q, k, v); - qkv_builder!(new.default where A: Default, S: DataOwned); - qkv_builder!(ones.ones where A: Clone + One, S: DataOwned); - qkv_builder!(zeros.zeros where A: Clone + Zero, S: DataOwned); + pub fn from_elem(shape: Sh, value: A) -> Self + where + Sh: ShapeBuilder, + A: Clone, + S: DataOwned, + { + let dim = shape.into_shape().raw_dim().clone(); + Self { + q: ArrayBase::from_elem(dim.clone(), value.clone()), + k: ArrayBase::from_elem(dim.clone(), value.clone()), + v: ArrayBase::from_elem(dim, value), + } + } - pub fn as_views(&self) -> (ArrayView, ArrayView, ArrayView) + pub fn as_qkv(&self) -> (ArrayView, ArrayView, ArrayView) where S: Data, { (self.q.view(), self.k.view(), self.v.view()) } + pub fn into_qkv(self) -> (ArrayBase, ArrayBase, ArrayBase) + where + S: DataOwned, + { + (self.q, 
self.k, self.v) + } + /// Return the [pattern](ndarray::Dimension::Pattern) of the dimension pub fn dim(&self) -> D::Pattern { self.q.dim() @@ -59,12 +80,16 @@ where pub fn raw_dim(&self) -> D { self.q.raw_dim() } - + /// Returns a slice of the current shape of the parameters. pub fn shape(&self) -> &[usize] { self.q.shape() } - + param_views!(into_owned::(self) where A: Clone, S: Data); param_views!(to_owned::(&self) where A: Clone, S: Data); + + param_views!(into_shared::(self) where A: Clone, S: DataOwned); param_views!(to_shared::(&self) where A: Clone, S: DataShared); + param_views!(view::<'a, ViewRepr>(&self) where S: Data); + param_views!(view_mut::<'a, ViewRepr>(&mut self) where S: DataMut); } diff --git a/models/transformers/tests/attention.rs b/models/transformers/tests/attention.rs index 98aa4c0e..db1efe2a 100644 --- a/models/transformers/tests/attention.rs +++ b/models/transformers/tests/attention.rs @@ -6,14 +6,14 @@ extern crate concision_core as concision; extern crate concision_transformers as transformers; use concision::{linarr, Matmul}; -use transformers::{AttentionHead, QKV}; +use transformers::{AttentionHead, Params}; use ndarray::prelude::*; #[test] fn test_qkv() { let shape = (2048, 10); - let params = QKV::::new(shape); + let params = Params::::new(shape); assert_eq!(params.q(), &Array::default(shape)); } @@ -23,7 +23,7 @@ fn test_qkv_matmul() { // generate some sample data let data = linarr(shape).unwrap(); // initialize the parameters - let params = QKV::::ones(shape); + let params = Params::::ones(shape); // calculate the expected result let exp = Array2::::ones(shape).dot(&data.t()); // calculate the result From 82a86c1216e1e10ed96f73bcf0753f279a92d3cf Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sun, 12 May 2024 11:51:53 -0500 Subject: [PATCH 02/23] update Signed-off-by: Joe McCain III --- models/linear/src/macros.rs | 87 ++++++++++++++++++++++-------- models/linear/src/params/params.rs | 21 ++++---- 2 files changed, 75 
insertions(+), 33 deletions(-) diff --git a/models/linear/src/macros.rs b/models/linear/src/macros.rs index 977fa5a3..5bfa0499 100644 --- a/models/linear/src/macros.rs +++ b/models/linear/src/macros.rs @@ -3,29 +3,6 @@ Contrib: FL03 */ -#[allow(unused_macros)] -macro_rules! params { - {$($k:ident: $v:expr),* $(,)?} => { - params!(@new $($k: $v),*); - }; - (@new bias: $b:expr, weights: $w:expr, mode: $mode:ty) => { - $crate::params::ParamsBase { - bias: $b, - weights: $w, - _mode: core::marker::PhantomData::<$mode>, - } - }; - (@new bias: $b:expr, weights: $w:expr) => { - params!(@new bias: $b, weights: $w, mode: $crate::params::mode::Biased); - }; - (@new bias: $b:expr, weights: $w:expr) => { - params!(@new bias: Some($b), weights: $w, mode: $crate::params::mode::Biased); - }; - (@new weights: $w:expr) => { - params!(@new bias: None, weights: $w, mode: $crate::params::mode::Unbiased); - }; -} - macro_rules! impl_param_builder { ($call:ident where $($rest:tt)*) => { impl_param_builder!(@impl $call where $($rest)*); @@ -46,3 +23,67 @@ macro_rules! impl_param_builder { } }; } + +macro_rules! 
ndview { + ($method:ident::$($rest:tt)*) => { + ndview!(@impl $method.$method::$($rest)*); + }; + ($method:ident.$call:ident::$($rest:tt)*) => { + ndview!(@impl $method.$call::$($rest)*); + }; + (@impl $method:ident.$call:ident::<$view:ident>(self) where $($rest:tt)*) => { + pub fn $method(self) -> $crate::params::ParamsBase<$view, D, K> + where + $($rest)* + { + ndview!(@apply $call(self)) + } + }; + (@impl $method:ident.$call:ident::<$view:ident>(mut self) where $($rest:tt)*) => { + pub fn $method(mut self) -> $crate::params::ParamsBase<$view, D, K> + where + $($rest)* + { + ndview!(@apply $call(self).as_mut()) + } + }; + (@impl $method:ident.$call:ident::<$view:ident>(&self) where $($rest:tt)*) => { + pub fn $method(&self) -> $crate::params::ParamsBase<$view, D, K> + where + $($rest)* + { + ndview!(@apply $call(self).as_ref()) + } + }; + (@impl $method:ident.$call:ident::<$view:ident>(&mut self) where $($rest:tt)*) => { + pub fn $method(&mut self) -> $crate::params::ParamsBase<$view, D, K> + where + $($rest)* + { + ndview!(@apply $call(self).as_mut()) + } + }; + (@impl $method:ident.$call:ident::<'a, $view:ident>(&self) where $($rest:tt)*) => { + pub fn $method(&self) -> $crate::params::ParamsBase<$view<&'_ A>, D, K> + where + $($rest)* + { + ndview!(@apply $call(&self).as_ref()) + } + }; + (@impl $method:ident.$call:ident::<'a, $view:ident>(&mut self) where $($rest:tt)*) => { + pub fn $method(&mut self) -> $crate::params::ParamsBase<$view<&'_ mut A>, D, K> + where + $($rest)* + { + ndview!(@apply $call(self).as_mut()) + } + }; + (@apply $call:ident($self:expr)$(.$as:ident())?) 
=> { + $crate::params::ParamsBase { + bias: $self.bias$(.$as())?.map(|arr| arr.$call()), + weights: $self.weights.$call(), + _mode: $self._mode, + } + }; +} diff --git a/models/linear/src/params/params.rs b/models/linear/src/params/params.rs index fab0a7b6..d50f5c36 100644 --- a/models/linear/src/params/params.rs +++ b/models/linear/src/params/params.rs @@ -30,16 +30,6 @@ where impl_param_builder!(ones where A: Clone + One, S: DataOwned); impl_param_builder!(zeros where A: Clone + Zero, S: DataOwned); - #[doc(hidden)] - pub fn build(shape: Sh, builder: F) -> Self - where - F: Fn(Sh) -> ArrayBase, - Sh: ShapeBuilder, - { - let _weights = builder(shape); - unimplemented!() - } - pub fn into_biased(self) -> ParamsBase where A: Default, @@ -107,6 +97,17 @@ where pub fn shape(&self) -> &[usize] { self.weights().shape() } + ndview!(into_owned::(self) where A: Clone, S: Data); + + ndview!(into_shared::(self) where A: Clone, S: DataOwned); + + ndview!(to_owned::(&self) where A: Clone, S: Data); + + ndview!(to_shared::(&self) where A: Clone, S: DataOwned); + + ndview!(view::<'a, ViewRepr>(&self) where A: Clone, S: Data); + + ndview!(view_mut::<'a, ViewRepr>(&mut self) where A: Clone, S: DataMut); } impl ParamsBase From e8d33f89670aa166f53dee86e1c747698c3bf734 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sun, 12 May 2024 14:00:46 -0500 Subject: [PATCH 03/23] update Signed-off-by: Joe McCain III --- concision/benches/default.rs | 105 ++++++++++++------ core/Cargo.toml | 11 +- core/src/nn/mod.rs | 6 +- core/src/nn/model.rs | 29 +++++ core/src/nn/model/config.rs | 18 +++ core/src/nn/{models => model}/module.rs | 2 +- core/src/nn/models/mod.rs | 13 --- core/src/nn/models/model.rs | 21 ---- models/linear/src/impls/impl_rand.rs | 6 +- models/linear/src/impls/model/impl_linear.rs | 16 +-- models/linear/src/impls/model/impl_model.rs | 8 +- models/linear/src/impls/params/impl_from.rs | 5 +- models/linear/src/impls/params/impl_params.rs | 11 +- 
models/linear/src/impls/params/impl_serde.rs | 3 +- models/linear/src/lib.rs | 7 +- models/linear/src/macros.rs | 1 + models/linear/src/mlp/mod.rs | 9 +- models/linear/src/model/config.rs | 9 +- models/linear/src/model/layout/layout.rs | 5 +- models/linear/src/model/linear.rs | 16 +-- models/linear/src/norm/layer.rs | 52 +++++++++ models/linear/src/norm/mod.rs | 14 +++ models/linear/src/params/mod.rs | 41 +------ models/linear/src/params/mode.rs | 13 ++- models/linear/src/params/params.rs | 6 +- models/linear/src/primitives.rs | 28 +++++ models/linear/tests/params.rs | 23 +++- models/transformers/Cargo.toml | 4 +- models/transformers/src/attention/head.rs | 6 +- models/transformers/src/macros.rs | 41 +++---- models/transformers/src/params/store.rs | 18 +-- 31 files changed, 327 insertions(+), 220 deletions(-) create mode 100644 core/src/nn/model.rs create mode 100644 core/src/nn/model/config.rs rename core/src/nn/{models => model}/module.rs (93%) delete mode 100644 core/src/nn/models/mod.rs delete mode 100644 core/src/nn/models/model.rs create mode 100644 models/linear/src/norm/layer.rs create mode 100644 models/linear/src/norm/mod.rs create mode 100644 models/linear/src/primitives.rs diff --git a/concision/benches/default.rs b/concision/benches/default.rs index 937f2387..0dca58bc 100644 --- a/concision/benches/default.rs +++ b/concision/benches/default.rs @@ -3,50 +3,89 @@ extern crate test; -use std::mem::replace; use test::Bencher; // bench: find the `BENCH_SIZE` first terms of the fibonacci sequence -static BENCH_SIZE: usize = 20; - -// recursive fibonacci -fn fibonacci(n: usize) -> u32 { - if n < 2 { - 1 - } else { - fibonacci(n - 1) + fibonacci(n - 2) - } -} - -// iterative fibonacci -struct Fibonacci { - curr: u32, - next: u32, -} - -impl Iterator for Fibonacci { - type Item = u32; - fn next(&mut self) -> Option { - let new_next = self.curr + self.next; - let new_curr = replace(&mut self.next, new_next); +const BENCH_SIZE: u32 = 20; - Some(replace(&mut 
self.curr, new_curr)) - } +#[bench] +fn fibonacci(b: &mut Bencher) { + // exact code to benchmark must be passed as a closure to the iter + // method of Bencher + b.iter(|| (0..BENCH_SIZE).map(fib::fibonacci).collect::>()) } -fn fibonacci_sequence() -> Fibonacci { - Fibonacci { curr: 1, next: 1 } +#[bench] +fn iter_fibonacci(b: &mut Bencher) { + b.iter(|| fib::Fibonacci::new().take(BENCH_SIZE as usize).collect::>()) } -// function to benchmark must be annotated with `#[bench]` #[bench] fn recursive_fibonacci(b: &mut Bencher) { // exact code to benchmark must be passed as a closure to the iter // method of Bencher - b.iter(|| (0..BENCH_SIZE).map(fibonacci).collect::>()) + b.iter(|| (0..BENCH_SIZE).map(fib::recursive_fibonacci).collect::>()) } -#[bench] -fn iterative_fibonacci(b: &mut Bencher) { - b.iter(|| fibonacci_sequence().take(BENCH_SIZE).collect::>()) -} +mod fib { + /// fibonacci(n) returns the nth fibonacci number + /// This function uses the definition of Fibonacci where: + /// F(0) = F(1) = 1 and F(n+1) = F(n) + F(n-1) for n>0 + /// + /// Warning: This will overflow the 128-bit unsigned integer at n=186 + pub fn fibonacci(n: u32) -> u128 { + // Use a and b to store the previous two values in the sequence + let mut a = 0; + let mut b = 1; + for _i in 0..n { + // As we iterate through, move b's value into a and the new computed + // value into b. + let c = a + b; + a = b; + b = c; + } + b + } + + /// fibonacci(n) returns the nth fibonacci number + /// This function uses the definition of Fibonacci where: + /// F(0) = F(1) = 1 and F(n+1) = F(n) + F(n-1) for n>0 + /// + /// Warning: This will overflow the 128-bit unsigned integer at n=186 + pub fn recursive_fibonacci(n: u32) -> u128 { + // Call the actual tail recursive implementation, with the extra + // arguments set up. 
+ _recursive_fibonacci(n, 0, 1) + } + + fn _recursive_fibonacci(n: u32, previous: u128, current: u128) -> u128 { + if n == 0 { + current + } else { + _recursive_fibonacci(n - 1, current, current + previous) + } + } + + pub struct Fibonacci { + curr: u32, + next: u32, + } + + impl Fibonacci { + pub fn new() -> Fibonacci { + Fibonacci { curr: 0, next: 1 } + } + } + + impl Iterator for Fibonacci { + type Item = u32; + + fn next(&mut self) -> Option { + use core::mem::replace; + let new_next = self.curr + self.next; + let new_curr = replace(&mut self.next, new_next); + + Some(replace(&mut self.curr, new_curr)) + } + } +} \ No newline at end of file diff --git a/core/Cargo.toml b/core/Cargo.toml index d112087b..0a847b31 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -56,7 +56,7 @@ rand-ext = [ "uuid/v4", ] -rng_std = [ +std-rng = [ "rand?/std", "rand?/std_rng", ] @@ -82,16 +82,18 @@ tracing = [ # ********* [FF] Environments ********* std = [ "alloc", + "std-rng", "ndarray/std", "num/std", - "rng_std", "scsys/std", "serde/std", "strum/std", "uuid/std" ] -wasm = [] +wasm = [ + "getrandom/js", +] wasi = [] @@ -158,6 +160,7 @@ lazy_static = "1" all-features = true rustc-args = ["--cfg", "docsrs"] -[target.wasm32-unknown-unknown] +[target.wasm32-unknown-unknown.dependencies] +getrandom = "0.2" [target.wasm32-wasi] diff --git a/core/src/nn/mod.rs b/core/src/nn/mod.rs index b82d7727..b5eb0d61 100644 --- a/core/src/nn/mod.rs +++ b/core/src/nn/mod.rs @@ -2,14 +2,14 @@ Appellation: nn Contrib: FL03 */ -pub use self::{error::ModelError, models::Module}; +pub use self::{error::ModelError, model::prelude::*}; pub mod error; -pub mod models; +pub mod model; pub(crate) mod prelude { pub use super::error::ModelError; - pub use super::models::prelude::*; + pub use super::model::prelude::*; } #[cfg(test)] diff --git a/core/src/nn/model.rs b/core/src/nn/model.rs new file mode 100644 index 00000000..c80c74ca --- /dev/null +++ b/core/src/nn/model.rs @@ -0,0 +1,29 @@ +/* + 
Appellation: model + Contrib: FL03 +*/ +pub use self::module::*; + +pub mod config; +pub mod module; + +pub(crate) mod prelude { + pub use super::config::*; + pub use super::module::*; + pub use super::Model; +} + +use crate::traits::Forward; + +pub trait Model: Module +where + Self: Forward, +{ + type Ctx; + type Data; + + fn children(&self) -> Vec>; + + fn context(&self) -> Self::Ctx; +} + diff --git a/core/src/nn/model/config.rs b/core/src/nn/model/config.rs new file mode 100644 index 00000000..7596c36b --- /dev/null +++ b/core/src/nn/model/config.rs @@ -0,0 +1,18 @@ +/* + Appellation: config + Contrib: FL03 +*/ +use crate::traits::Config; + +pub struct ModelConfig { + pub name: String, + _children: Vec>, +} + +impl Config for ModelConfig {} + +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize,))] +pub struct ConfigBase { + pub id: usize, + pub name: &'static str, +} diff --git a/core/src/nn/models/module.rs b/core/src/nn/model/module.rs similarity index 93% rename from core/src/nn/models/module.rs rename to core/src/nn/model/module.rs index 79fbe8ed..8d8ee23b 100644 --- a/core/src/nn/models/module.rs +++ b/core/src/nn/model/module.rs @@ -4,7 +4,7 @@ */ use crate::{Config, Predict}; -pub type DynModule = Box>; +pub type ModuleDyn = Box>; pub type DynModuleExt = Box>; pub type Stack = Vec>>; diff --git a/core/src/nn/models/mod.rs b/core/src/nn/models/mod.rs deleted file mode 100644 index 37e9ebe7..00000000 --- a/core/src/nn/models/mod.rs +++ /dev/null @@ -1,13 +0,0 @@ -/* - Appellation: models - Contrib: FL03 -*/ -pub use self::{model::*, module::*}; - -pub mod model; -pub mod module; - -pub(crate) mod prelude { - pub use super::model::*; - pub use super::module::*; -} diff --git a/core/src/nn/models/model.rs b/core/src/nn/models/model.rs deleted file mode 100644 index 00736aaf..00000000 --- a/core/src/nn/models/model.rs +++ /dev/null @@ -1,21 +0,0 @@ -/* - Appellation: model - Contrib: FL03 -*/ -use super::{DynModule, Module}; -use 
crate::traits::Forward; - -pub trait Model: Module -where - Self: Forward, -{ - type Ctx; - type Data; - - fn children(&self) -> Vec>; -} - -pub struct ConfigBase { - pub id: usize, - pub name: String, -} diff --git a/models/linear/src/impls/impl_rand.rs b/models/linear/src/impls/impl_rand.rs index 8b441fcc..227a976d 100644 --- a/models/linear/src/impls/impl_rand.rs +++ b/models/linear/src/impls/impl_rand.rs @@ -4,7 +4,7 @@ */ #![cfg(feature = "rand")] -use crate::params::{ParamMode, ParamsBase}; +use crate::params::ParamsBase; use crate::{bias_dim, Linear}; use concision::prelude::GenerateRandom; use concision::rand::rand_distr::{uniform, Distribution, StandardNormal}; @@ -15,10 +15,9 @@ impl Linear where A: Float + uniform::SampleUniform, D: RemoveAxis, - K: ParamMode, StandardNormal: Distribution, { - pub fn uniform(self) -> Self { + pub fn uniform(self) -> Self where K: 'static { let biased = self.is_biased(); Self { params: self.params.init_uniform(biased), @@ -31,7 +30,6 @@ impl ParamsBase, D, K> where A: Float + uniform::SampleUniform, D: RemoveAxis, - K: ParamMode, StandardNormal: Distribution, { pub(crate) fn dk(&self) -> A { diff --git a/models/linear/src/impls/model/impl_linear.rs b/models/linear/src/impls/model/impl_linear.rs index e182436b..762c8fea 100644 --- a/models/linear/src/impls/model/impl_linear.rs +++ b/models/linear/src/impls/model/impl_linear.rs @@ -6,10 +6,7 @@ use crate::{Config, Linear, LinearParams, ParamMode}; use core::borrow::{Borrow, BorrowMut}; use nd::RemoveAxis; -impl Linear -where - K: ParamMode, -{ +impl Linear where K: ParamMode { pub fn from_features(inputs: usize, outputs: usize) -> Self where A: Clone + Default, @@ -23,29 +20,26 @@ where impl Borrow> for Linear where D: RemoveAxis, - K: ParamMode, { fn borrow(&self) -> &Config { &self.config } } -impl Borrow> for Linear +impl Borrow> for Linear where D: RemoveAxis, - K: ParamMode, { - fn borrow(&self) -> &LinearParams { + fn borrow(&self) -> &LinearParams { &self.params } } 
-impl BorrowMut> for Linear +impl BorrowMut> for Linear where D: RemoveAxis, - K: ParamMode, { - fn borrow_mut(&mut self) -> &mut LinearParams { + fn borrow_mut(&mut self) -> &mut LinearParams { &mut self.params } } diff --git a/models/linear/src/impls/model/impl_model.rs b/models/linear/src/impls/model/impl_model.rs index d61633f1..475c66bd 100644 --- a/models/linear/src/impls/model/impl_model.rs +++ b/models/linear/src/impls/model/impl_model.rs @@ -2,17 +2,16 @@ Appellation: impl_model Contrib: FL03 */ -use crate::{Config, Linear, LinearParams, ParamMode}; +use crate::{Config, Linear, LinearParams}; use concision::prelude::{Module, Predict, PredictError}; use nd::RemoveAxis; impl Module for Linear where D: RemoveAxis, - K: ParamMode, { type Config = Config; - type Params = LinearParams; + type Params = LinearParams; fn config(&self) -> &Self::Config { &self.config @@ -30,8 +29,7 @@ where impl Predict for Linear where D: RemoveAxis, - K: ParamMode, - LinearParams: Predict, + LinearParams: Predict, { type Output = V; diff --git a/models/linear/src/impls/params/impl_from.rs b/models/linear/src/impls/params/impl_from.rs index 2ea9df1e..5564d08b 100644 --- a/models/linear/src/impls/params/impl_from.rs +++ b/models/linear/src/impls/params/impl_from.rs @@ -2,8 +2,7 @@ Appellation: impl_from Contrib: FL03 */ -use crate::params::{Biased, NodeBase, Pair, ParamMode, ParamsBase, Unbiased}; -use crate::Features; +use crate::{Biased, Features, NodeBase, Pair, ParamsBase, Unbiased}; #[cfg(all(feature = "alloc", no_std))] use alloc::vec; use core::marker::PhantomData; @@ -100,7 +99,6 @@ where impl From<(Array1, Option)> for ParamsBase, Ix1, K> where A: Clone, - K: ParamMode, { fn from((weights, bias): (Array1, Option)) -> Self { Self { @@ -114,7 +112,6 @@ where impl From> for ParamsBase where D: RemoveAxis, - K: ParamMode, S: RawData, { fn from((weights, bias): NodeBase) -> Self { diff --git a/models/linear/src/impls/params/impl_params.rs 
b/models/linear/src/impls/params/impl_params.rs index 1f29c8d3..dc555d46 100644 --- a/models/linear/src/impls/params/impl_params.rs +++ b/models/linear/src/impls/params/impl_params.rs @@ -2,7 +2,6 @@ Appellation: params Contrib: FL03 */ -use crate::params::mode::*; use crate::params::ParamsBase; use concision::prelude::{Predict, PredictError}; use core::ops::Add; @@ -13,7 +12,6 @@ use num::complex::ComplexFloat; impl ParamsBase where D: RemoveAxis, - K: ParamMode, S: RawData, { pub fn activate(&mut self, args: &X, f: F) -> Y @@ -31,7 +29,6 @@ where A: Dot, Output = B>, B: Add<&'a ArrayBase, Output = B>, D: RemoveAxis, - K: ParamMode, S: Data, T: NdFloat, { @@ -118,7 +115,6 @@ macro_rules! impl_predict { A: Dot, Output = B>, B: for<'a> Add<&'a ArrayBase, Output = B>, D: RemoveAxis, - K: ParamMode, S: Data, T: ComplexFloat, { @@ -134,13 +130,12 @@ macro_rules! impl_predict { } } }; - (@impl $lt:lifetime $name:ident) => { - impl<'a, A, B, T, S, D, K> Predict for $name + (@impl $name:ident<&'a $rhs:ident>) => { + impl<'a, A, B, T, S, D, K> Predict<&'a $rhs> for $name where A: Dot, Output = B>, - B: for<'a> Add<&'a ArrayBase, Output = B>, + B: Add<&'a ArrayBase, Output = B>, D: RemoveAxis, - K: ParamMode, S: Data, T: ComplexFloat, { diff --git a/models/linear/src/impls/params/impl_serde.rs b/models/linear/src/impls/params/impl_serde.rs index 9d0b77b5..2d8641fa 100644 --- a/models/linear/src/impls/params/impl_serde.rs +++ b/models/linear/src/impls/params/impl_serde.rs @@ -4,7 +4,7 @@ */ #![cfg(feature = "serde")] -use crate::params::{Entry, ParamMode, ParamsBase}; +use crate::params::{Entry, ParamsBase}; use core::marker::PhantomData; use nd::*; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -33,7 +33,6 @@ impl Serialize for ParamsBase where A: Serialize, D: RemoveAxis + Serialize, - K: ParamMode, S: Data, ::Smaller: Dimension + Serialize, { diff --git a/models/linear/src/lib.rs b/models/linear/src/lib.rs index 81709daa..aa3a6b57 100644 --- 
a/models/linear/src/lib.rs +++ b/models/linear/src/lib.rs @@ -16,12 +16,13 @@ extern crate concision_core as concision; extern crate ndarray as nd; pub use self::model::{Config, Features, Layout, Linear}; -pub use self::params::{mode::*, LinearParams}; +pub use self::params::{mode::*, ParamsBase}; #[allow(unused_imports)] -pub use self::{traits::*, utils::*}; +pub use self::{traits::*, primitives::*, utils::*}; #[macro_use] pub(crate) mod macros; +pub(crate) mod primitives; #[macro_use] pub(crate) mod seal; pub(crate) mod utils; @@ -33,6 +34,7 @@ pub mod dense; #[doc(hidden)] pub mod mlp; pub mod model; +pub mod norm; pub mod params; pub mod traits; @@ -53,6 +55,7 @@ mod impls { pub mod prelude { pub use crate::model::prelude::*; + pub use crate::norm::prelude::*; pub use crate::params::prelude::*; pub use crate::traits::*; } diff --git a/models/linear/src/macros.rs b/models/linear/src/macros.rs index 5bfa0499..ab2ee6eb 100644 --- a/models/linear/src/macros.rs +++ b/models/linear/src/macros.rs @@ -10,6 +10,7 @@ macro_rules! 
impl_param_builder { (@impl $call:ident where $($rest:tt)*) => { pub fn $call(shape: Sh) -> Self where + K: $crate::params::mode::ParamMode, Sh: ndarray::ShapeBuilder, $($rest)* { diff --git a/models/linear/src/mlp/mod.rs b/models/linear/src/mlp/mod.rs index 218ee165..715d7681 100644 --- a/models/linear/src/mlp/mod.rs +++ b/models/linear/src/mlp/mod.rs @@ -18,6 +18,11 @@ pub trait MultiLayerPerceptron { type Output; } -pub trait Neuron { - type Rho; +pub trait Neuron { +} + +pub trait Rho { + type Output; + + fn activate(&self, args: T) -> Self::Output; } diff --git a/models/linear/src/model/config.rs b/models/linear/src/model/config.rs index 6f4f40d8..c2dce962 100644 --- a/models/linear/src/model/config.rs +++ b/models/linear/src/model/config.rs @@ -3,16 +3,13 @@ Contrib: FL03 */ use super::layout::{Features, Layout}; -use crate::params::mode::*; +use crate::params::{Biased, Unbiased}; use core::marker::PhantomData; use nd::{Dimension, IntoDimension, Ix2, RemoveAxis}; #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] -pub struct Config -where - D: Dimension, -{ +pub struct Config{ pub layout: Layout, pub name: String, _biased: PhantomData, @@ -21,7 +18,6 @@ where impl Config where D: Dimension, - K: ParamMode, { pub fn new() -> Self { Self { @@ -157,7 +153,6 @@ where impl concision::Config for Config where D: Dimension, - K: ParamMode, { } diff --git a/models/linear/src/model/layout/layout.rs b/models/linear/src/model/layout/layout.rs index d4ce877a..a2a4d83b 100644 --- a/models/linear/src/model/layout/layout.rs +++ b/models/linear/src/model/layout/layout.rs @@ -9,10 +9,7 @@ use nd::{Dimension, RemoveAxis, ShapeBuilder, ShapeError}; #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] -pub struct Layout -where - D: Dimension, -{ +pub struct Layout { pub(crate) dim: D, 
pub(crate) features: Features, } diff --git a/models/linear/src/model/linear.rs b/models/linear/src/model/linear.rs index b10bca02..5291e6ee 100644 --- a/models/linear/src/model/linear.rs +++ b/models/linear/src/model/linear.rs @@ -17,18 +17,17 @@ where D: Dimension, { pub(crate) config: Config, - pub(crate) params: LinearParams, + pub(crate) params: LinearParams, } impl Linear where D: RemoveAxis, - K: ParamMode, { pub fn from_config(config: Config) -> Self where A: Clone + Default, - K: 'static, + K: ParamMode, { let params = LinearParams::default(config.dim()); Self { config, params } @@ -37,13 +36,14 @@ where pub fn from_layout(layout: Layout) -> Self where A: Clone + Default, + K: ParamMode, { let config = Config::::new().with_layout(layout); let params = LinearParams::default(config.dim()); Self { config, params } } - pub fn with_params(self, params: LinearParams) -> Linear + pub fn with_params(self, params: LinearParams) -> Linear where E: RemoveAxis, { @@ -71,11 +71,11 @@ where self.params.weights_mut() } - pub const fn params(&self) -> &LinearParams { + pub const fn params(&self) -> &LinearParams { &self.params } - pub fn params_mut(&mut self) -> &mut LinearParams { + pub fn params_mut(&mut self) -> &mut LinearParams { &mut self.params } @@ -89,8 +89,8 @@ where } } - pub fn is_biased(&self) -> bool { - K::BIASED || self.config().is_biased() + pub fn is_biased(&self) -> bool where K: 'static { + self.config().is_biased() } pub fn with_name(self, name: impl ToString) -> Self { diff --git a/models/linear/src/norm/layer.rs b/models/linear/src/norm/layer.rs new file mode 100644 index 00000000..44f3726a --- /dev/null +++ b/models/linear/src/norm/layer.rs @@ -0,0 +1,52 @@ +/* + Appellation: layer + Contrib: FL03 +*/ +use crate::{LinearParams, ParamMode}; +use nd::RemoveAxis; +use nd::prelude::*; + +pub struct LayerNorm where D: Dimension, { + config: LayerNormConfig, + params: LinearParams, +} + +pub struct LayerNormConfig { + pub dim: D, + pub eps: f64, +} + 
+impl LayerNormConfig { + pub fn new() -> Self where D: Default { + Self { dim: D::default(), eps: 1e-5 } + } + + pub fn create(dim: D, eps: f64) -> Self where D: Default { + Self { dim, eps } + } + + pub fn with_dim(dim: D) -> Self { + Self { dim, eps: 1e-5 } + } + +} + +impl Default for LayerNormConfig where D: Default { + fn default() -> Self { + Self { + dim: D::default(), + eps: 1e-5, + } + } +} + + + +impl LayerNorm where D: RemoveAxis, K: ParamMode, { + pub fn from_shape(shape: Sh) -> Self where A: Default, Sh: ShapeBuilder { + let dim = shape.into_shape().raw_dim().clone(); + let config = LayerNormConfig::with_dim(dim.clone()); + let params = LinearParams::::default(dim); + Self { config, params } + } +} \ No newline at end of file diff --git a/models/linear/src/norm/mod.rs b/models/linear/src/norm/mod.rs new file mode 100644 index 00000000..5be2e4e9 --- /dev/null +++ b/models/linear/src/norm/mod.rs @@ -0,0 +1,14 @@ +/* + Appellation: norm + Contrib: FL03 +*/ +//! # Normalization +//! +//! +pub use self::layer::LayerNorm; + +pub mod layer; + +pub(crate) mod prelude { + pub use super::layer::LayerNorm; +} \ No newline at end of file diff --git a/models/linear/src/params/mod.rs b/models/linear/src/params/mod.rs index b5d72583..15b3761e 100644 --- a/models/linear/src/params/mod.rs +++ b/models/linear/src/params/mod.rs @@ -4,7 +4,7 @@ */ #[doc(inline)] pub use self::entry::{Entry, Param}; -pub use self::mode::{Biased, ParamMode, Unbiased}; +pub use self::mode::*; pub use self::params::ParamsBase; mod params; @@ -12,44 +12,11 @@ mod params; pub mod entry; pub mod mode; -use nd::{ArrayBase, Ix0, Ix1}; - -pub(crate) type Pair = (A, B); -pub(crate) type MaybePair = Pair>; -pub(crate) type NodeBase = MaybePair, ArrayBase>; -pub(crate) type Node = NodeBase, D, E>; - -macro_rules! params_ty { - ($($name:ident<$repr:ident>),* $(,)?) 
=> { - $(params_ty!(@impl $name<$repr>);)* - }; - (@impl $name:ident<$repr:ident>) => { - pub type $name = $crate::params::ParamsBase, D, K>; - }; -} +#[doc(inline)] +pub use crate::primitives::params::*; -params_ty!(LinearParams, LinearParamsShared,); pub(crate) mod prelude { - pub use super::{LinearParams, LinearParamsShared}; + pub use super::mode::*; } -#[cfg(test)] -mod tests { - use super::*; - use core::str::FromStr; - - #[test] - fn test_param_kind() { - for i in [(Param::Bias, "bias"), (Param::Weight, "weight")].iter() { - let kind = Param::from_str(i.1).unwrap(); - assert_eq!(i.0, kind); - } - } - - #[test] - fn test_ones() { - let a = LinearParams::::ones((1, 300)); - assert!(a.is_biased()); - } -} diff --git a/models/linear/src/params/mode.rs b/models/linear/src/params/mode.rs index 53e16726..e8ca29c0 100644 --- a/models/linear/src/params/mode.rs +++ b/models/linear/src/params/mode.rs @@ -3,9 +3,6 @@ Contrib: FL03 */ -pub trait State { - type Mode: ParamMode; -} pub trait ParamMode: 'static { const BIASED: bool = false; @@ -17,6 +14,10 @@ pub trait ParamMode: 'static { private!(); } +/* + ************* Implementations ************* +*/ + impl ParamMode for Option where T: 'static, @@ -30,9 +31,9 @@ where seal!(); } -macro_rules! param_mode { +macro_rules! mode { {$($T:ident: $opt:expr),* $(,)?} => { - $(param_mode!(@impl $T: $opt);)* + $(mode!(@impl $T: $opt);)* }; (@impl $T:ident: $opt:expr) => { #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] @@ -52,7 +53,7 @@ macro_rules! param_mode { } -param_mode! { +mode! 
{ Biased: true, Unbiased: false, } diff --git a/models/linear/src/params/params.rs b/models/linear/src/params/params.rs index d50f5c36..64386695 100644 --- a/models/linear/src/params/params.rs +++ b/models/linear/src/params/params.rs @@ -2,9 +2,7 @@ Appellation: params Contrib: FL03 */ -use super::mode::*; -use super::Node; -use crate::{build_bias, Features}; +use crate::{build_bias, Biased, Features, Node, Unbiased}; use core::marker::PhantomData; use nd::*; use num::{One, Zero}; @@ -23,7 +21,6 @@ where impl ParamsBase where D: RemoveAxis, - K: ParamMode, S: RawData, { impl_param_builder!(default where A: Default, S: DataOwned); @@ -152,7 +149,6 @@ where } impl ParamsBase where - K: ParamMode, S: RawData, { pub fn set_node(&mut self, idx: usize, node: Node) diff --git a/models/linear/src/primitives.rs b/models/linear/src/primitives.rs new file mode 100644 index 00000000..aa1e8137 --- /dev/null +++ b/models/linear/src/primitives.rs @@ -0,0 +1,28 @@ +/* + Appellation: primitives + Contrib: FL03 +*/ +pub use self::params::*; +use nd::{ArrayBase, Ix0, Ix1}; + +pub(crate) type Pair = (A, B); + +pub(crate) type MaybePair = Pair>; + +pub(crate) type NodeBase = MaybePair, ArrayBase>; + +pub(crate) type Node = NodeBase, D, E>; + +pub(crate) mod params { + + macro_rules! params_ty { + ($($name:ident<$repr:ident>),* $(,)?) 
=> { + $(params_ty!(@impl $name<$repr>);)* + }; + (@impl $name:ident<$repr:ident>) => { + pub type $name = $crate::params::ParamsBase, D, K>; + }; + } + + params_ty!(LinearParams, LinearParamsShared,); +} \ No newline at end of file diff --git a/models/linear/tests/params.rs b/models/linear/tests/params.rs index acfaa093..56af0d6a 100644 --- a/models/linear/tests/params.rs +++ b/models/linear/tests/params.rs @@ -6,13 +6,24 @@ extern crate concision_core as concision; extern crate concision_linear as linear; use concision::Predict; -use linear::{Features, LinearParams}; +use linear::Features; +use linear::params::{LinearParams, Param, Unbiased}; + +use core::str::FromStr; use ndarray::prelude::*; const SAMPLES: usize = 20; const INPUTS: usize = 5; const DMODEL: usize = 3; +#[test] +fn test_keys() { + for i in [(Param::Bias, "bias"), (Param::Weight, "weight")].iter() { + let kind = Param::from_str(i.1).unwrap(); + assert_eq!(i.0, kind); + } +} + #[test] fn test_linear_params() { let (samples, inputs, outputs) = (SAMPLES, INPUTS, DMODEL); @@ -22,3 +33,13 @@ fn test_linear_params() { let y: Array2 = params.predict(&data).unwrap(); assert_eq!(y.dim(), (samples, outputs)); } + +#[test] +fn test_ndbuilders() { + let shape = (300, 10); + let params = LinearParams::::ones(shape); + assert!(params.is_biased()); + let params = LinearParams::::zeros(shape); + assert!(!params.is_biased()); + +} \ No newline at end of file diff --git a/models/transformers/Cargo.toml b/models/transformers/Cargo.toml index 34782d1c..7e91e5e3 100644 --- a/models/transformers/Cargo.toml +++ b/models/transformers/Cargo.toml @@ -112,8 +112,8 @@ version = "1" optional = true version = "0.1" -[dev-dependencies] -lazy_static.workspace = true +[dev-dependencies.lazy_static] +workspace = true [package.metadata.docs.rs] all-features = true diff --git a/models/transformers/src/attention/head.rs b/models/transformers/src/attention/head.rs index f3bb9faa..71494a18 100644 --- 
a/models/transformers/src/attention/head.rs +++ b/models/transformers/src/attention/head.rs @@ -48,7 +48,7 @@ where } access!(params::); - ndbuilder!(new.default where A: Default, S: DataOwned); - ndbuilder!(ones where A: Clone + num::One, S: DataOwned); - ndbuilder!(zeros where A: Clone + num::Zero, S: DataOwned); + ndbuilder!(new::default() where A: Default, S: DataOwned); + ndbuilder!(ones() where A: Clone + num::One, S: DataOwned); + ndbuilder!(zeros() where A: Clone + num::Zero, S: DataOwned); } diff --git a/models/transformers/src/macros.rs b/models/transformers/src/macros.rs index f4712afc..37e4f3e3 100644 --- a/models/transformers/src/macros.rs +++ b/models/transformers/src/macros.rs @@ -33,42 +33,33 @@ macro_rules! access { } macro_rules! ndbuilder { - ($method:ident $($rest:tt)*) => { - ndbuilder!(@impl $method $($rest)*); + ($method:ident$(::$call:ident)?() where $($rest:tt)*) => { + ndbuilder!(@impl $method$(::$call)?() where $($rest)*); }; - (@impl $method:ident where $($rest:tt)*) => { - pub fn $method(shape: Sh) -> Self - where - Sh: ndarray::ShapeBuilder, - $($rest)* - { - Self::builder(shape, ArrayBase::$method) - } + (@impl $method:ident() where $($rest:tt)*) => { + ndbuilder!(@impl $method::$method() where $($rest)*); }; - (@impl $method:ident.$call:ident where $($rest:tt)*) => { - pub fn $method(shape: Sh) -> Self - where - Sh: ndarray::ShapeBuilder, - $($rest)* - { + (@impl $method:ident::$call:ident() where $($rest:tt)*) => { + pub fn $method>(shape: Sh) -> Self where $($rest)* { Self::builder(shape, ArrayBase::$call) } }; } -macro_rules! param_views { +// # TODO: +macro_rules! 
ndview { ($method:ident::$($rest:tt)*) => { - param_views!(@impl $method.$method::$($rest)*); + ndview!(@impl $method.$method::$($rest)*); }; ($method:ident.$call:ident::$($rest:tt)*) => { - param_views!(@impl $method.$call::$($rest)*); + ndview!(@impl $method.$call::$($rest)*); }; (@impl $method:ident.$call:ident::<$view:ident>(self) where $($rest:tt)*) => { pub fn $method(self) -> $crate::params::ParamsBase<$view, D> where $($rest)* { - param_views!(@apply $call(self)) + ndview!(@apply $call(self)) } }; (@impl $method:ident.$call:ident::<$view:ident>(mut self) where $($rest:tt)*) => { @@ -76,7 +67,7 @@ macro_rules! param_views { where $($rest)* { - param_views!(@apply $call(self)) + ndview!(@apply $call(self)) } }; (@impl $method:ident.$call:ident::<$view:ident>(&self) where $($rest:tt)*) => { @@ -84,7 +75,7 @@ macro_rules! param_views { where $($rest)* { - param_views!(@apply $call(self)) + ndview!(@apply $call(self)) } }; (@impl $method:ident.$call:ident::<$view:ident>(&mut self) where $($rest:tt)*) => { @@ -92,7 +83,7 @@ macro_rules! param_views { where $($rest)* { - param_views!(@apply $call(self)) + ndview!(@apply $call(self)) } }; (@impl $method:ident.$call:ident::<'a, $view:ident>(&self) where $($rest:tt)*) => { @@ -100,7 +91,7 @@ macro_rules! param_views { where $($rest)* { - param_views!(@apply $call(self)) + ndview!(@apply $call(self)) } }; (@impl $method:ident.$call:ident::<'a, $view:ident>(&mut self) where $($rest:tt)*) => { @@ -108,7 +99,7 @@ macro_rules! 
param_views { where $($rest)* { - param_views!(@apply $call(self)) + ndview!(@apply $call(self)) } }; (@apply $call:ident($self:expr)) => { diff --git a/models/transformers/src/params/store.rs b/models/transformers/src/params/store.rs index 6fcdca1b..78f49bd3 100644 --- a/models/transformers/src/params/store.rs +++ b/models/transformers/src/params/store.rs @@ -34,9 +34,9 @@ where } } - ndbuilder!(new.default where A: Default, S: DataOwned); - ndbuilder!(ones where A: Clone + One, S: DataOwned); - ndbuilder!(zeros where A: Clone + Zero, S: DataOwned); + ndbuilder!(new::default() where A: Default, S: DataOwned); + ndbuilder!(ones() where A: Clone + One, S: DataOwned); + ndbuilder!(zeros() where A: Clone + Zero, S: DataOwned); access!(q, k, v); @@ -84,12 +84,12 @@ where pub fn shape(&self) -> &[usize] { self.q.shape() } - param_views!(into_owned::(self) where A: Clone, S: Data); - param_views!(to_owned::(&self) where A: Clone, S: Data); + ndview!(into_owned::(self) where A: Clone, S: Data); + ndview!(to_owned::(&self) where A: Clone, S: Data); - param_views!(into_shared::(self) where A: Clone, S: DataOwned); - param_views!(to_shared::(&self) where A: Clone, S: DataShared); + ndview!(into_shared::(self) where A: Clone, S: DataOwned); + ndview!(to_shared::(&self) where A: Clone, S: DataShared); - param_views!(view::<'a, ViewRepr>(&self) where S: Data); - param_views!(view_mut::<'a, ViewRepr>(&mut self) where S: DataMut); + ndview!(view::<'a, ViewRepr>(&self) where S: Data); + ndview!(view_mut::<'a, ViewRepr>(&mut self) where S: DataMut); } From 5c944319ce4d7cf2b918ee7c35d3f35e21c1fcf4 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sun, 12 May 2024 14:34:34 -0500 Subject: [PATCH 04/23] update Signed-off-by: Joe McCain III --- concision/benches/default.rs | 16 +++++-- core/src/macros.rs | 43 +++++++++++++++++++ core/src/nn/model.rs | 1 - models/linear/src/impls/impl_rand.rs | 5 ++- models/linear/src/impls/model/impl_linear.rs | 5 ++- models/linear/src/lib.rs | 2 +- 
models/linear/src/mlp/mod.rs | 3 +- models/linear/src/model/config.rs | 8 +--- models/linear/src/model/linear.rs | 5 ++- models/linear/src/norm/layer.rs | 44 ++++++++++++++------ models/linear/src/norm/mod.rs | 6 +-- models/linear/src/params/mod.rs | 2 - models/linear/src/params/mode.rs | 3 +- models/linear/src/primitives.rs | 4 +- models/linear/tests/params.rs | 5 +-- models/transformers/src/attention/head.rs | 3 +- models/transformers/src/codec/model.rs | 19 ++++----- models/transformers/src/macros.rs | 29 ------------- models/transformers/src/params/store.rs | 2 +- 19 files changed, 121 insertions(+), 84 deletions(-) diff --git a/concision/benches/default.rs b/concision/benches/default.rs index 0dca58bc..9e07b340 100644 --- a/concision/benches/default.rs +++ b/concision/benches/default.rs @@ -17,14 +17,22 @@ fn fibonacci(b: &mut Bencher) { #[bench] fn iter_fibonacci(b: &mut Bencher) { - b.iter(|| fib::Fibonacci::new().take(BENCH_SIZE as usize).collect::>()) + b.iter(|| { + fib::Fibonacci::new() + .take(BENCH_SIZE as usize) + .collect::>() + }) } #[bench] fn recursive_fibonacci(b: &mut Bencher) { // exact code to benchmark must be passed as a closure to the iter // method of Bencher - b.iter(|| (0..BENCH_SIZE).map(fib::recursive_fibonacci).collect::>()) + b.iter(|| { + (0..BENCH_SIZE) + .map(fib::recursive_fibonacci) + .collect::>() + }) } mod fib { @@ -65,7 +73,7 @@ mod fib { _recursive_fibonacci(n - 1, current, current + previous) } } - + pub struct Fibonacci { curr: u32, next: u32, @@ -88,4 +96,4 @@ mod fib { Some(replace(&mut self.curr, new_curr)) } } -} \ No newline at end of file +} diff --git a/core/src/macros.rs b/core/src/macros.rs index 7498b0fd..5af8ac49 100644 --- a/core/src/macros.rs +++ b/core/src/macros.rs @@ -155,3 +155,46 @@ macro_rules! builder { } }; } + +#[macro_export] +macro_rules! getters { + ($($call:ident$(.$field:ident)?<$out:ty>),* $(,)?) 
=> { + $($crate::getters!(@impl $call$(.$field)?<$out>);)* + }; + ($via:ident::<[$($call:ident$(.$field:ident)?<$out:ty>),* $(,)?]>) => { + $($crate::getters!(@impl $via::$call$(.$field)?<$out>);)* + }; + ($($call:ident$(.$field:ident)?),* $(,)? => $out:ty) => { + $($crate::getters!(@impl $call$(.$field)?<$out>);)* + }; + ($via:ident::<[$($call:ident$(.$field:ident)?),* $(,)?]> => $out:ty) => { + $crate::getters!($via::<[$($call$(.$field)?<$out>),*]>); + }; + + (@impl $call:ident<$out:ty>) => { + $crate::getters!(@impl $call.$call<$out>); + }; + (@impl $via:ident::$call:ident<$out:ty>) => { + $crate::getters!(@impl $via::$call.$call<$out>); + }; + (@impl $call:ident.$field:ident<$out:ty>) => { + pub fn $call(&self) -> &$out { + &self.$field + } + paste::paste! { + pub fn [< $call _mut>](&mut self) -> &mut $out { + &mut self.$field + } + } + }; + (@impl $via:ident::$call:ident.$field:ident<$out:ty>) => { + pub fn $call(&self) -> &$out { + &self.$via.$field + } + paste::paste! { + pub fn [< $call _mut>](&mut self) -> &mut $out { + &mut self.$via.$field + } + } + }; +} diff --git a/core/src/nn/model.rs b/core/src/nn/model.rs index c80c74ca..8991d08e 100644 --- a/core/src/nn/model.rs +++ b/core/src/nn/model.rs @@ -26,4 +26,3 @@ where fn context(&self) -> Self::Ctx; } - diff --git a/models/linear/src/impls/impl_rand.rs b/models/linear/src/impls/impl_rand.rs index 227a976d..683b59d8 100644 --- a/models/linear/src/impls/impl_rand.rs +++ b/models/linear/src/impls/impl_rand.rs @@ -17,7 +17,10 @@ where D: RemoveAxis, StandardNormal: Distribution, { - pub fn uniform(self) -> Self where K: 'static { + pub fn uniform(self) -> Self + where + K: 'static, + { let biased = self.is_biased(); Self { params: self.params.init_uniform(biased), diff --git a/models/linear/src/impls/model/impl_linear.rs b/models/linear/src/impls/model/impl_linear.rs index 762c8fea..e189c639 100644 --- a/models/linear/src/impls/model/impl_linear.rs +++ b/models/linear/src/impls/model/impl_linear.rs @@ -6,7 
+6,10 @@ use crate::{Config, Linear, LinearParams, ParamMode}; use core::borrow::{Borrow, BorrowMut}; use nd::RemoveAxis; -impl Linear where K: ParamMode { +impl Linear +where + K: ParamMode, +{ pub fn from_features(inputs: usize, outputs: usize) -> Self where A: Clone + Default, diff --git a/models/linear/src/lib.rs b/models/linear/src/lib.rs index aa3a6b57..2b011f81 100644 --- a/models/linear/src/lib.rs +++ b/models/linear/src/lib.rs @@ -18,7 +18,7 @@ extern crate ndarray as nd; pub use self::model::{Config, Features, Layout, Linear}; pub use self::params::{mode::*, ParamsBase}; #[allow(unused_imports)] -pub use self::{traits::*, primitives::*, utils::*}; +pub use self::{primitives::*, traits::*, utils::*}; #[macro_use] pub(crate) mod macros; diff --git a/models/linear/src/mlp/mod.rs b/models/linear/src/mlp/mod.rs index 715d7681..c040a93a 100644 --- a/models/linear/src/mlp/mod.rs +++ b/models/linear/src/mlp/mod.rs @@ -18,8 +18,7 @@ pub trait MultiLayerPerceptron { type Output; } -pub trait Neuron { -} +pub trait Neuron {} pub trait Rho { type Output; diff --git a/models/linear/src/model/config.rs b/models/linear/src/model/config.rs index c2dce962..cc098318 100644 --- a/models/linear/src/model/config.rs +++ b/models/linear/src/model/config.rs @@ -9,7 +9,7 @@ use nd::{Dimension, IntoDimension, Ix2, RemoveAxis}; #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] -pub struct Config{ +pub struct Config { pub layout: Layout, pub name: String, _biased: PhantomData, @@ -150,11 +150,7 @@ where } } -impl concision::Config for Config -where - D: Dimension, -{ -} +impl concision::Config for Config where D: Dimension {} impl Default for Config where diff --git a/models/linear/src/model/linear.rs b/models/linear/src/model/linear.rs index 5291e6ee..1c24029c 100644 --- a/models/linear/src/model/linear.rs +++ b/models/linear/src/model/linear.rs @@ -89,7 +89,10 @@ where } } - pub fn 
is_biased(&self) -> bool where K: 'static { + pub fn is_biased(&self) -> bool + where + K: 'static, + { self.config().is_biased() } diff --git a/models/linear/src/norm/layer.rs b/models/linear/src/norm/layer.rs index 44f3726a..cd73c25e 100644 --- a/models/linear/src/norm/layer.rs +++ b/models/linear/src/norm/layer.rs @@ -3,10 +3,13 @@ Contrib: FL03 */ use crate::{LinearParams, ParamMode}; -use nd::RemoveAxis; use nd::prelude::*; +use nd::RemoveAxis; -pub struct LayerNorm where D: Dimension, { +pub struct LayerNorm +where + D: Dimension, +{ config: LayerNormConfig, params: LinearParams, } @@ -17,21 +20,32 @@ pub struct LayerNormConfig { } impl LayerNormConfig { - pub fn new() -> Self where D: Default { - Self { dim: D::default(), eps: 1e-5 } + pub fn new() -> Self + where + D: Default, + { + Self { + dim: D::default(), + eps: 1e-5, + } } - pub fn create(dim: D, eps: f64) -> Self where D: Default { + pub fn create(dim: D, eps: f64) -> Self + where + D: Default, + { Self { dim, eps } } pub fn with_dim(dim: D) -> Self { Self { dim, eps: 1e-5 } } - } -impl Default for LayerNormConfig where D: Default { +impl Default for LayerNormConfig +where + D: Default, +{ fn default() -> Self { Self { dim: D::default(), @@ -40,13 +54,19 @@ impl Default for LayerNormConfig where D: Default { } } - - -impl LayerNorm where D: RemoveAxis, K: ParamMode, { - pub fn from_shape(shape: Sh) -> Self where A: Default, Sh: ShapeBuilder { +impl LayerNorm +where + D: RemoveAxis, + K: ParamMode, +{ + pub fn from_shape(shape: Sh) -> Self + where + A: Default, + Sh: ShapeBuilder, + { let dim = shape.into_shape().raw_dim().clone(); let config = LayerNormConfig::with_dim(dim.clone()); let params = LinearParams::::default(dim); Self { config, params } } -} \ No newline at end of file +} diff --git a/models/linear/src/norm/mod.rs b/models/linear/src/norm/mod.rs index 5be2e4e9..febad096 100644 --- a/models/linear/src/norm/mod.rs +++ b/models/linear/src/norm/mod.rs @@ -3,12 +3,12 @@ Contrib: FL03 */ //! 
# Normalization -//! -//! +//! +//! pub use self::layer::LayerNorm; pub mod layer; pub(crate) mod prelude { pub use super::layer::LayerNorm; -} \ No newline at end of file +} diff --git a/models/linear/src/params/mod.rs b/models/linear/src/params/mod.rs index 15b3761e..c6fcedc5 100644 --- a/models/linear/src/params/mod.rs +++ b/models/linear/src/params/mod.rs @@ -15,8 +15,6 @@ pub mod mode; #[doc(inline)] pub use crate::primitives::params::*; - pub(crate) mod prelude { pub use super::mode::*; } - diff --git a/models/linear/src/params/mode.rs b/models/linear/src/params/mode.rs index e8ca29c0..9d7e7367 100644 --- a/models/linear/src/params/mode.rs +++ b/models/linear/src/params/mode.rs @@ -3,7 +3,6 @@ Contrib: FL03 */ - pub trait ParamMode: 'static { const BIASED: bool = false; @@ -15,7 +14,7 @@ pub trait ParamMode: 'static { } /* - ************* Implementations ************* + ************* Implementations ************* */ impl ParamMode for Option diff --git a/models/linear/src/primitives.rs b/models/linear/src/primitives.rs index aa1e8137..6cda38fd 100644 --- a/models/linear/src/primitives.rs +++ b/models/linear/src/primitives.rs @@ -14,7 +14,7 @@ pub(crate) type NodeBase = MaybePair, Array pub(crate) type Node = NodeBase, D, E>; pub(crate) mod params { - + macro_rules! params_ty { ($($name:ident<$repr:ident>),* $(,)?) 
=> { $(params_ty!(@impl $name<$repr>);)* @@ -25,4 +25,4 @@ pub(crate) mod params { } params_ty!(LinearParams, LinearParamsShared,); -} \ No newline at end of file +} diff --git a/models/linear/tests/params.rs b/models/linear/tests/params.rs index 56af0d6a..50edc168 100644 --- a/models/linear/tests/params.rs +++ b/models/linear/tests/params.rs @@ -6,8 +6,8 @@ extern crate concision_core as concision; extern crate concision_linear as linear; use concision::Predict; -use linear::Features; use linear::params::{LinearParams, Param, Unbiased}; +use linear::Features; use core::str::FromStr; use ndarray::prelude::*; @@ -41,5 +41,4 @@ fn test_ndbuilders() { assert!(params.is_biased()); let params = LinearParams::::zeros(shape); assert!(!params.is_biased()); - -} \ No newline at end of file +} diff --git a/models/transformers/src/attention/head.rs b/models/transformers/src/attention/head.rs index 71494a18..1b201a8f 100644 --- a/models/transformers/src/attention/head.rs +++ b/models/transformers/src/attention/head.rs @@ -3,6 +3,7 @@ Contrib: FL03 */ use crate::params::ParamsBase; +use concision::getters; use nd::*; pub struct AttentionHead, D = Ix2> @@ -47,7 +48,7 @@ where &mut self.params } - access!(params::); + getters!(params::<[q, k, v]> => ArrayBase); ndbuilder!(new::default() where A: Default, S: DataOwned); ndbuilder!(ones() where A: Clone + num::One, S: DataOwned); ndbuilder!(zeros() where A: Clone + num::Zero, S: DataOwned); diff --git a/models/transformers/src/codec/model.rs b/models/transformers/src/codec/model.rs index 4c3cc180..3fb02b6a 100644 --- a/models/transformers/src/codec/model.rs +++ b/models/transformers/src/codec/model.rs @@ -3,6 +3,7 @@ Contrib: FL03 */ use super::{Decoder, Encoder}; +use concision::{builder, getters}; #[derive(Default)] pub struct Codec { @@ -16,20 +17,14 @@ impl Codec { CodecBuilder::new() } - pub fn context(&self) -> &Context { - &self.ctx - } - - pub fn decoder(&self) -> &Decoder { - &self.decoder - } - - pub fn encoder(&self) -> 
&Encoder { - &self.encoder - } + getters!( + context.ctx, + decoder, + encoder, + ); } -concision::builder!( +builder!( #[derive(Default)] CodecBuilder:: { ctx: Context, diff --git a/models/transformers/src/macros.rs b/models/transformers/src/macros.rs index 37e4f3e3..aabf8345 100644 --- a/models/transformers/src/macros.rs +++ b/models/transformers/src/macros.rs @@ -3,35 +3,6 @@ Contrib: FL03 */ -macro_rules! access { - ($($var:ident),* $(,)?) => { - $(access!(@impl $var);)* - }; - ($via:ident::<$($var:ident),* $(,)?>) => { - $(access!(@impl $via::$var);)* - }; - (@impl $var:ident) => { - pub fn $var(&self) -> &ArrayBase { - &self.$var - } - paste::paste! { - pub fn [< $var _mut>](&mut self) -> &mut ArrayBase { - &mut self.$var - } - } - }; - (@impl $via:ident::$var:ident) => { - pub fn $var(&self) -> &ArrayBase { - &self.$via.$var - } - paste::paste! { - pub fn [< $var _mut>](&mut self) -> &mut ArrayBase { - &mut self.$via.$var - } - } - }; -} - macro_rules! ndbuilder { ($method:ident$(::$call:ident)?() where $($rest:tt)*) => { ndbuilder!(@impl $method$(::$call)?() where $($rest)*); diff --git a/models/transformers/src/params/store.rs b/models/transformers/src/params/store.rs index 78f49bd3..1e316dab 100644 --- a/models/transformers/src/params/store.rs +++ b/models/transformers/src/params/store.rs @@ -38,7 +38,7 @@ where ndbuilder!(ones() where A: Clone + One, S: DataOwned); ndbuilder!(zeros() where A: Clone + Zero, S: DataOwned); - access!(q, k, v); + concision::getters!(q, k, v => ArrayBase); pub fn from_elem(shape: Sh, value: A) -> Self where From fd97e3df4b416c810c81a2725f47e0ace51abbd9 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sun, 12 May 2024 15:40:54 -0500 Subject: [PATCH 05/23] update Signed-off-by: Joe McCain III --- core/Cargo.toml | 1 + core/src/lib.rs | 2 + core/src/math/mod.rs | 15 ++++++ core/src/math/traits.rs | 77 +++++++++++++++++++++++++++++ core/src/traits/mod.rs | 4 +- core/src/traits/{math.rs => num.rs} | 35 ------------- 
models/linear/src/norm/layer.rs | 21 ++++++-- 7 files changed, 113 insertions(+), 42 deletions(-) create mode 100644 core/src/math/mod.rs create mode 100644 core/src/math/traits.rs rename core/src/traits/{math.rs => num.rs} (85%) diff --git a/core/Cargo.toml b/core/Cargo.toml index 0a847b31..b0891b64 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -116,6 +116,7 @@ required-features = ["approx"] [dependencies] ndarray.workspace = true num.workspace = true +paste.workspace = true scsys.workspace = true smart-default.workspace = true strum.workspace = true diff --git a/core/src/lib.rs b/core/src/lib.rs index b7ee82c7..b45a9a8f 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -24,6 +24,7 @@ pub(crate) mod primitives; pub mod error; pub mod func; +pub mod math; pub mod nn; pub mod ops; pub mod params; @@ -38,6 +39,7 @@ pub mod prelude { pub use super::error::prelude::*; pub use super::func::prelude::*; + pub use super::math::prelude::*; pub use super::nn::prelude::*; pub use super::ops::prelude::*; pub use super::params::prelude::*; diff --git a/core/src/math/mod.rs b/core/src/math/mod.rs new file mode 100644 index 00000000..bc9dc16f --- /dev/null +++ b/core/src/math/mod.rs @@ -0,0 +1,15 @@ +/* + Appellation: math + Contrib: FL03 +*/ +//! # Mathematics +//! +//! This module focuses on implementing various mathematical objects and operations that are +//! critical to the development of machine learning algorithms. +pub use self::traits::*; + +pub mod traits; + +pub(crate) mod prelude { + pub use super::traits::*; +} diff --git a/core/src/math/traits.rs b/core/src/math/traits.rs new file mode 100644 index 00000000..e817ecda --- /dev/null +++ b/core/src/math/traits.rs @@ -0,0 +1,77 @@ +/* + Appellation: traits + Contrib: FL03 +*/ +use nd::{Array, ArrayBase, Data, Dimension}; +use num::complex::{Complex, ComplexFloat}; + +macro_rules! 
unary { + ($($name:ident::$method:ident),*) => { + $(unary!(@impl $name::$method);)* + }; + (@impl $name:ident::$method:ident) => { + pub trait $name { + type Output; + + fn $method(self) -> Self::Output; + } + }; + (@fn $($method:ident),* $(,)?) => { + $(fn $method(self) -> Self::Output;)* + }; +} + +unary!(Abs::abs, SquareRoot::sqrt); + +/* + ********* Implementations ********* +*/ +macro_rules! fwd_unop { + ($name:ident::$method:ident<[$($T:ty),* $(,)?]>) => { + fwd_unop!($name::$method.$method<[$($T: $T),*]>); + }; + ($name:ident::$method:ident.$call:ident<[$($T:ty: $O:ty),* $(,)?]>) => { + $(fwd_unop!(@impl $name::$method.$call<$T> -> $O);)* + }; + (@impl $name:ident::$method:ident$(.$call:ident)?<$T:ty>) => { + fwd_unop!(@impl $name::$method$(.$call)?<$T> -> $T); + }; + (@impl $name:ident::$method:ident<$T:ty> -> $O:ty) => { + fwd_unop!(@impl $name::$method.$method<$T> -> $O); + }; + (@impl $name:ident::$method:ident.$call:ident<$T:ty> -> $O:ty) => { + impl $name for $T { + type Output = $O; + + fn $method(self) -> Self::Output { + <$T>::$call(self) + } + } + }; +} + +fwd_unop!(SquareRoot::sqrt<[f32, f64]>); + +impl SquareRoot for Complex +where + Complex: ComplexFloat, +{ + type Output = Self; + + fn sqrt(self) -> Self::Output { + ComplexFloat::sqrt(self) + } +} + +impl SquareRoot for ArrayBase +where + A: Clone + SquareRoot, + D: Dimension, + S: Data, +{ + type Output = Array; + + fn sqrt(self) -> Self::Output { + self.mapv(|x| x.sqrt()) + } +} diff --git a/core/src/traits/mod.rs b/core/src/traits/mod.rs index 6663f98b..c8e1a33d 100644 --- a/core/src/traits/mod.rs +++ b/core/src/traits/mod.rs @@ -5,7 +5,7 @@ pub use self::prelude::*; pub mod generator; -pub mod math; +pub mod num; pub mod ops; pub mod predict; pub mod train; @@ -44,8 +44,8 @@ pub(crate) mod misc { pub(crate) mod prelude { pub use super::arr::prelude::*; pub use super::generator::*; - pub use super::math::*; pub use super::misc::prelude::*; + pub use super::num::*; pub use super::ops::*; pub 
use super::predict::*; pub use super::train::*; diff --git a/core/src/traits/math.rs b/core/src/traits/num.rs similarity index 85% rename from core/src/traits/math.rs rename to core/src/traits/num.rs index 1d499a72..d21e77f4 100644 --- a/core/src/traits/math.rs +++ b/core/src/traits/num.rs @@ -52,10 +52,6 @@ pub trait RoundTo { fn round_to(&self, places: usize) -> Self; } -pub trait SquareRoot { - fn sqrt(self) -> Self; -} - /* ********* Implementations ********* */ @@ -151,34 +147,3 @@ where crate::round_to(*self, places) } } - -impl SquareRoot for f32 { - fn sqrt(self) -> Self { - f32::sqrt(self) - } -} - -impl SquareRoot for f64 { - fn sqrt(self) -> Self { - f64::sqrt(self) - } -} - -impl SquareRoot for Complex -where - T: Float, -{ - fn sqrt(self) -> Self { - Complex::::sqrt(self) - } -} - -impl SquareRoot for Array -where - D: Dimension, - T: Float, -{ - fn sqrt(self) -> Self { - self.mapv(|x| x.sqrt()) - } -} diff --git a/models/linear/src/norm/layer.rs b/models/linear/src/norm/layer.rs index cd73c25e..182563cc 100644 --- a/models/linear/src/norm/layer.rs +++ b/models/linear/src/norm/layer.rs @@ -6,6 +6,10 @@ use crate::{LinearParams, ParamMode}; use nd::prelude::*; use nd::RemoveAxis; +// #62 +/// [LayerNorm] adhears to the [Layer Normalization](https://arxiv.org/abs/1607.06450) algorithm. 
+/// +/// ### Resources pub struct LayerNorm where D: Dimension, @@ -19,11 +23,11 @@ pub struct LayerNormConfig { pub eps: f64, } -impl LayerNormConfig { - pub fn new() -> Self - where - D: Default, - { +impl LayerNormConfig +where + D: Dimension, +{ + pub fn new() -> Self { Self { dim: D::default(), eps: 1e-5, @@ -69,4 +73,11 @@ where let params = LinearParams::::default(dim); Self { config, params } } + + pub fn config(&self) -> &LayerNormConfig { + &self.config + } + + concision::getters!(params => LinearParams); + } From 21ad106455085104f7e2db431538c0b54957af93 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sun, 12 May 2024 15:48:00 -0500 Subject: [PATCH 06/23] update Signed-off-by: Joe McCain III --- models/linear/src/norm/layer.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/models/linear/src/norm/layer.rs b/models/linear/src/norm/layer.rs index 182563cc..0149bf9e 100644 --- a/models/linear/src/norm/layer.rs +++ b/models/linear/src/norm/layer.rs @@ -78,6 +78,12 @@ where &self.config } - concision::getters!(params => LinearParams); + pub fn params(&self) -> &LinearParams { + &self.params + } + + pub fn params_mut(&mut self) -> &mut LinearParams { + &mut self.params + } } From 21cd81d0eb094e0851882452f6970acf08c45bec Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Mon, 13 May 2024 15:24:53 -0500 Subject: [PATCH 07/23] update Signed-off-by: Joe McCain III --- core/src/rand/initialize.rs | 15 +++- models/linear/Cargo.toml | 4 +- models/linear/src/impls/impl_rand.rs | 77 ++++++++++--------- models/linear/src/impls/params/impl_params.rs | 13 ++-- models/linear/src/model/config.rs | 70 +++++++++-------- models/linear/src/model/layout/layout.rs | 4 +- models/linear/src/model/linear.rs | 16 +++- models/linear/src/norm/layer.rs | 4 +- models/linear/src/norm/mod.rs | 2 + models/linear/src/params/entry.rs | 4 +- models/linear/src/params/params.rs | 24 ++++-- models/linear/src/utils.rs | 8 +- models/linear/tests/model.rs | 6 +- 
models/linear/tests/params.rs | 34 ++++---- 14 files changed, 171 insertions(+), 110 deletions(-) diff --git a/core/src/rand/initialize.rs b/core/src/rand/initialize.rs index f625fac1..e3948967 100644 --- a/core/src/rand/initialize.rs +++ b/core/src/rand/initialize.rs @@ -2,6 +2,7 @@ Appellation: initialize Contrib: FL03 */ +use core::ops::Neg; use nd::{ArrayBase, DataOwned, Dimension, RawData, ShapeBuilder}; use ndrand::RandomExt; use num::traits::Float; @@ -69,6 +70,16 @@ where &mut rngs::StdRng::seed_from_u64(seed), ) } + + /// A [uniform](rand_distr::uniform::Uniform) generator with values between u(-dk, dk) + fn uniform(shape: Sh, dk: A) -> ArrayBase + where + A: Clone + Neg + SampleUniform, + S: DataOwned, + Sh: ShapeBuilder, + { + Self::genrand(shape, Uniform::new(dk.clone().neg(), dk)) + } /// Generate a random array with values between u(-a, a) where a is the reciprocal of the value at the given axis fn uniform_along(shape: Sh, axis: usize) -> ArrayBase where @@ -78,10 +89,10 @@ where { let dim = shape.into_shape().raw_dim().clone(); let dk = A::from(dim[axis]).unwrap().recip(); - Self::uniform(dim, -dk, dk) + Self::uniform(dim, dk) } /// A [uniform](rand_distr::uniform::Uniform) generator with values between u(-dk, dk) - fn uniform(shape: Sh, a: A, b: A) -> ArrayBase + fn uniform_between(shape: Sh, a: A, b: A) -> ArrayBase where A: SampleUniform, S: DataOwned, diff --git a/models/linear/Cargo.toml b/models/linear/Cargo.toml index 50ee7411..ee925322 100644 --- a/models/linear/Cargo.toml +++ b/models/linear/Cargo.toml @@ -88,11 +88,11 @@ test = true [[test]] name = "model" -required-features = ["rand"] +required-features = ["std"] [[test]] name = "params" -required-features = ["rand"] +required-features = ["std"] [build-dependencies] diff --git a/models/linear/src/impls/impl_rand.rs b/models/linear/src/impls/impl_rand.rs index 683b59d8..e1a6fc6b 100644 --- a/models/linear/src/impls/impl_rand.rs +++ b/models/linear/src/impls/impl_rand.rs @@ -6,7 +6,7 @@ use 
crate::params::ParamsBase; use crate::{bias_dim, Linear}; -use concision::prelude::GenerateRandom; +use concision::prelude::InitializeExt; use concision::rand::rand_distr::{uniform, Distribution, StandardNormal}; use nd::*; use num::Float; @@ -15,15 +15,19 @@ impl Linear where A: Float + uniform::SampleUniform, D: RemoveAxis, + K: 'static, StandardNormal: Distribution, { - pub fn uniform(self) -> Self - where - K: 'static, - { - let biased = self.is_biased(); + pub fn uniform(self) -> Self { Self { - params: self.params.init_uniform(biased), + params: self.params.uniform(), + ..self + } + } + + pub fn uniform_between(self, low: A, high: A) -> Self { + Self { + params: self.params.uniform_between(low, high), ..self } } @@ -33,43 +37,46 @@ impl ParamsBase, D, K> where A: Float + uniform::SampleUniform, D: RemoveAxis, + K: 'static, StandardNormal: Distribution, { + /// Computes the reciprocal of the input features. pub(crate) fn dk(&self) -> A { - A::from(self.in_features()).unwrap().recip().sqrt() + A::from(self.in_features()).unwrap().recip() } - - pub fn init_uniform(mut self, biased: bool) -> Self { - if biased { - self = self.init_bias(); - } - self.init_weight() - } - - pub fn init_bias(mut self) -> Self { - let dim = bias_dim(self.raw_dim()); - self.bias = Some(Array::uniform_between(self.dk(), dim)); - self + /// Computes the square root of the reciprical of the input features. 
+ pub(crate) fn dk_sqrt(&self) -> A { + self.dk().sqrt() } - pub fn init_weight(mut self) -> Self { - self.weights = Array::uniform_between(self.dk(), self.raw_dim()); - self + pub fn uniform(self) -> Self { + let dk = self.dk_sqrt(); + self.uniform_between(-dk, dk) } - pub fn uniform(self) -> Self { - let dk = self.dk(); - let bias = if self.is_biased() { - let dim = bias_dim(self.raw_dim()); - Some(Array::uniform_between(dk, dim)) + pub fn uniform_between(self, low: A, high: A) -> Self { + if self.is_biased() && !self.bias.is_some() { + let b_dim = bias_dim(self.raw_dim()); + Self { + bias: Some(Array::uniform_between(b_dim, low, high)), + weights: Array::uniform_between(self.raw_dim(), low, high), + _mode: self._mode, + } + } else if !self.is_biased() && self.bias.is_some() { + Self { + bias: None, + weights: Array::uniform_between(self.raw_dim(), low, high), + _mode: self._mode, + } } else { - None - }; - let weights = Array::uniform_between(dk, self.raw_dim()); - Self { - bias, - weights, - _mode: self._mode, + Self { + bias: self + .bias + .as_ref() + .map(|b| Array::uniform_between(b.raw_dim(), low, high)), + weights: Array::uniform_between(self.raw_dim(), low, high), + _mode: self._mode, + } } } } diff --git a/models/linear/src/impls/params/impl_params.rs b/models/linear/src/impls/params/impl_params.rs index dc555d46..ad667f00 100644 --- a/models/linear/src/impls/params/impl_params.rs +++ b/models/linear/src/impls/params/impl_params.rs @@ -79,29 +79,26 @@ where } } -impl PartialEq<(ArrayBase, Option>)> for ParamsBase +impl PartialEq<(ArrayBase, Option>)> + for ParamsBase where A: PartialEq, D: RemoveAxis, S: Data, { fn eq(&self, (weights, bias): &(ArrayBase, Option>)) -> bool { - self.weights == weights && self.bias == *bias + self.weights() == weights && self.bias() == bias.as_ref() } } -impl PartialEq<(ArrayBase, ArrayBase)> for ParamsBase +impl PartialEq<(ArrayBase, ArrayBase)> for ParamsBase where A: PartialEq, D: RemoveAxis, S: Data, { fn eq(&self, 
(weights, bias): &(ArrayBase, ArrayBase)) -> bool { - let mut cmp = self.weights == weights; - if let Some(b) = &self.bias { - cmp &= b == bias; - } - cmp + self.weights() == weights && self.bias() == Some(bias) } } diff --git a/models/linear/src/model/config.rs b/models/linear/src/model/config.rs index cc098318..1c558441 100644 --- a/models/linear/src/model/config.rs +++ b/models/linear/src/model/config.rs @@ -5,14 +5,15 @@ use super::layout::{Features, Layout}; use crate::params::{Biased, Unbiased}; use core::marker::PhantomData; -use nd::{Dimension, IntoDimension, Ix2, RemoveAxis}; +use nd::prelude::*; +use nd::{IntoDimension, RemoveAxis, ShapeError}; #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] -pub struct Config { +pub struct Config { pub layout: Layout, pub name: String, - _biased: PhantomData, + _biased: PhantomData, } impl Config @@ -23,23 +24,41 @@ where Self { layout: Layout::default(), name: String::new(), - _biased: PhantomData, + _biased: PhantomData::, } } + pub fn from_dim(dim: D) -> Result + where + D: Dimension, + { + let layout = Layout::from_dim(dim)?; + let res = Self::new().with_layout(layout); + Ok(res) + } + + pub fn from_shape(shape: Sh) -> Self + where + D: RemoveAxis, + Sh: ShapeBuilder, + { + let layout = Layout::from_shape(shape); + Self::new().with_layout(layout) + } + pub fn into_biased(self) -> Config { Config { - _biased: PhantomData, layout: self.layout, name: self.name, + _biased: PhantomData::, } } pub fn into_unbiased(self) -> Config { Config { - _biased: PhantomData, layout: self.layout, name: self.name, + _biased: PhantomData::, } } @@ -60,15 +79,15 @@ where _biased: self._biased, } } - + /// Returns a cloned reference to the [dimension](ndarray::Dimension) of the [layout](Layout) pub fn dim(&self) -> D { - self.layout.dim().clone() + self.layout().dim() } pub fn into_pattern(self) -> D::Pattern { self.dim().into_pattern() } - + /// 
This function attempts to convert the [layout](Layout) of the [Config] into a new [dimension](ndarray::Dimension) pub fn into_dimensionality(self, dim: E) -> Result, nd::ShapeError> where E: Dimension, @@ -80,7 +99,8 @@ where }; Ok(tmp) } - + /// Determine whether the [Config] is [Biased]; + /// Returns true by comparing the [TypeId](core::any::TypeId) of `K` against the [TypeId](core::any::TypeId) of the [Biased] type pub fn is_biased(&self) -> bool where K: 'static, @@ -89,21 +109,21 @@ where TypeId::of::() == TypeId::of::() } - + /// Returns an instance to the [Features] of the [Layout] pub fn features(&self) -> Features { - self.layout.features() + self.layout().features() } - - pub fn layout(&self) -> &Layout { + /// Returns an owned reference to the [Layout] + pub const fn layout(&self) -> &Layout { &self.layout } - + /// Returns an immutable reference to the `name` of the model. pub fn name(&self) -> &str { &self.name } pub fn ndim(&self) -> usize { - self.layout.ndim() + self.layout().ndim() } } @@ -112,7 +132,7 @@ impl Config { Self { layout: Layout::new((outputs, inputs).into_dimension()), name: String::new(), - _biased: PhantomData, + _biased: PhantomData::, } } } @@ -121,17 +141,10 @@ impl Config where D: Dimension, { + /// The default constructor method for building [Biased] configurations. 
pub fn biased() -> Self { Self::new() } - - pub fn from_dim_biased(dim: D) -> Self - where - D: RemoveAxis, - { - let layout = Layout::from_dim(dim).unwrap(); - Self::new().with_layout(layout) - } } impl Config @@ -141,13 +154,6 @@ where pub fn unbiased() -> Self { Self::new() } - - pub fn from_dim(dim: D) -> Config - where - D: RemoveAxis, - { - Config::::new().with_layout(Layout::from_dim(dim).unwrap()) - } } impl concision::Config for Config where D: Dimension {} diff --git a/models/linear/src/model/layout/layout.rs b/models/linear/src/model/layout/layout.rs index a2a4d83b..47a1ef8f 100644 --- a/models/linear/src/model/layout/layout.rs +++ b/models/linear/src/model/layout/layout.rs @@ -49,8 +49,8 @@ where self.dim.slice_mut() } - pub fn dim(&self) -> &D { - &self.dim + pub fn dim(&self) -> D { + self.dim.clone() } pub fn features(&self) -> Features { diff --git a/models/linear/src/model/linear.rs b/models/linear/src/model/linear.rs index 1c24029c..06b8798c 100644 --- a/models/linear/src/model/linear.rs +++ b/models/linear/src/model/linear.rs @@ -3,14 +3,14 @@ Contrib: FL03 */ use super::{Config, Layout}; -use crate::{Biased, LinearParams, ParamMode}; +use crate::{Biased, LinearParams, ParamMode, Unbiased}; use concision::prelude::{Predict, Result}; use nd::{Array, Dimension, Ix2, RemoveAxis}; /// An implementation of a linear model. /// /// In an effort to streamline the api, the [Linear] model relies upon a [ParamMode] type ([Biased] or [Unbiased](crate::params::mode::Unbiased)) -/// which enables the model to automatically determine whether or not to include a bias term. Doing so enables us to forward many methods +/// which enables the model to automatically determine whether or not to include a bias term. Doing so allows the model to inherit several methods /// familar to the underlying [ndarray](https://docs.rs/ndarray) crate. 
pub struct Linear where @@ -82,6 +82,7 @@ where pub fn into_biased(self) -> Linear where A: Default, + K: 'static, { Linear { config: self.config.into_biased(), @@ -89,6 +90,17 @@ where } } + pub fn into_unbiased(self) -> Linear + where + A: Default, + K: 'static, + { + Linear { + config: self.config.into_unbiased(), + params: self.params.into_unbiased(), + } + } + pub fn is_biased(&self) -> bool where K: 'static, diff --git a/models/linear/src/norm/layer.rs b/models/linear/src/norm/layer.rs index 0149bf9e..4f250a87 100644 --- a/models/linear/src/norm/layer.rs +++ b/models/linear/src/norm/layer.rs @@ -2,6 +2,7 @@ Appellation: layer Contrib: FL03 */ +use super::EPSILON; use crate::{LinearParams, ParamMode}; use nd::prelude::*; use nd::RemoveAxis; @@ -53,7 +54,7 @@ where fn default() -> Self { Self { dim: D::default(), - eps: 1e-5, + eps: EPSILON, } } } @@ -85,5 +86,4 @@ where pub fn params_mut(&mut self) -> &mut LinearParams { &mut self.params } - } diff --git a/models/linear/src/norm/mod.rs b/models/linear/src/norm/mod.rs index febad096..67fbca27 100644 --- a/models/linear/src/norm/mod.rs +++ b/models/linear/src/norm/mod.rs @@ -9,6 +9,8 @@ pub use self::layer::LayerNorm; pub mod layer; +pub const EPSILON: f64 = 1e-5; + pub(crate) mod prelude { pub use super::layer::LayerNorm; } diff --git a/models/linear/src/params/entry.rs b/models/linear/src/params/entry.rs index 2a193361..7e83bd45 100644 --- a/models/linear/src/params/entry.rs +++ b/models/linear/src/params/entry.rs @@ -48,11 +48,11 @@ where D: RemoveAxis, S: RawData, { - pub fn bias(data: ArrayBase) -> Self { + pub fn from_bias(data: ArrayBase) -> Self { Self::Bias(data) } - pub fn weight(data: ArrayBase) -> Self { + pub fn from_weight(data: ArrayBase) -> Self { Self::Weight(data) } } diff --git a/models/linear/src/params/params.rs b/models/linear/src/params/params.rs index 64386695..5dfd1b74 100644 --- a/models/linear/src/params/params.rs +++ b/models/linear/src/params/params.rs @@ -30,13 +30,21 @@ where pub 
fn into_biased(self) -> ParamsBase where A: Default, + K: 'static, S: DataOwned, { + if self.is_biased() { + return ParamsBase { + bias: self.bias, + weights: self.weights, + _mode: PhantomData::, + }; + } let sm = crate::bias_dim(self.raw_dim()); ParamsBase { bias: Some(ArrayBase::default(sm)), weights: self.weights, - _mode: PhantomData, + _mode: PhantomData::, } } @@ -44,7 +52,7 @@ where ParamsBase { bias: None, weights: self.weights, - _mode: PhantomData, + _mode: PhantomData::, } } @@ -71,9 +79,14 @@ where pub fn in_features(&self) -> usize { self.features().dmodel() } - - pub fn is_biased(&self) -> bool { - self.bias().is_some() + /// Returns true if the parameter store is biased; + /// Compares the [TypeId](core::any::TypeId) of the store with the [Biased](crate::Biased) type. + pub fn is_biased(&self) -> bool + where + K: 'static, + { + use core::any::TypeId; + TypeId::of::() == TypeId::of::() } pub fn ndim(&self) -> usize { @@ -149,6 +162,7 @@ where } impl ParamsBase where + K: 'static, S: RawData, { pub fn set_node(&mut self, idx: usize, node: Node) diff --git a/models/linear/src/utils.rs b/models/linear/src/utils.rs index abd8fb77..445e0862 100644 --- a/models/linear/src/utils.rs +++ b/models/linear/src/utils.rs @@ -2,7 +2,9 @@ Appellation: utils Contrib: FL03 */ -use nd::*; +use crate::params::Biased; +use core::any::TypeId; +use nd::{ArrayBase, Axis, Dimension, RawData, RemoveAxis}; /// A utilitarian funciton for building bias tensors. pub(crate) fn build_bias(biased: bool, dim: D, builder: F) -> Option> @@ -30,3 +32,7 @@ where dim.remove_axis(Axis(1)) } } + +pub fn is_biased() -> bool { + TypeId::of::() == TypeId::of::() +} diff --git a/models/linear/tests/model.rs b/models/linear/tests/model.rs index 70c6f8b4..165a75b2 100644 --- a/models/linear/tests/model.rs +++ b/models/linear/tests/model.rs @@ -23,13 +23,14 @@ lazy_static! 
{ #[test] fn test_config() { let dim = FEATURES.clone().into_dimension(); - let config = Config::from_dim_biased(dim); + let config = Config::::from_shape(dim); assert!(config.is_biased()); - let config = Config::from_dim(dim); + let config = Config::::from_shape(dim); assert!(!config.is_biased()); } #[test] +#[cfg(feature = "rand")] fn test_linear() { let (samples, (outputs, inputs)) = SHAPE; @@ -42,6 +43,7 @@ fn test_linear() { } #[test] +#[cfg(feature = "rand")] fn test_model_modes() { let (_samples, (outputs, inputs)) = SHAPE; diff --git a/models/linear/tests/params.rs b/models/linear/tests/params.rs index 50edc168..aaa474a9 100644 --- a/models/linear/tests/params.rs +++ b/models/linear/tests/params.rs @@ -2,19 +2,19 @@ Appellation: params Contrib: FL03 */ +#![allow(unused_imports)] extern crate concision_core as concision; extern crate concision_linear as linear; use concision::Predict; +use core::str::FromStr; use linear::params::{LinearParams, Param, Unbiased}; use linear::Features; - -use core::str::FromStr; use ndarray::prelude::*; const SAMPLES: usize = 20; -const INPUTS: usize = 5; -const DMODEL: usize = 3; +const D_MODEL: usize = 5; +const FEATURES: usize = 3; #[test] fn test_keys() { @@ -25,20 +25,24 @@ fn test_keys() { } #[test] +fn test_builders() { + let shape = (D_MODEL, FEATURES); + let params = LinearParams::::ones(shape); + assert!(params.is_biased()); + assert_eq!(params.weights(), &Array2::ones(shape)); + assert_eq!(params.bias().unwrap(), &Array1::ones(D_MODEL)); + let params = LinearParams::::zeros(shape); + assert!(!params.is_biased()); + assert_eq!(params.weights(), &Array2::zeros(shape)); +} + +#[test] +#[cfg(feature = "rand")] fn test_linear_params() { - let (samples, inputs, outputs) = (SAMPLES, INPUTS, DMODEL); + let (samples, inputs, outputs) = (SAMPLES, D_MODEL, FEATURES); let features = Features::new(outputs, inputs); - let data = Array2::::zeros((samples, inputs)); + let data = Array2::::ones((samples, inputs)); let params = 
LinearParams::biased(features).uniform(); let y: Array2 = params.predict(&data).unwrap(); assert_eq!(y.dim(), (samples, outputs)); } - -#[test] -fn test_ndbuilders() { - let shape = (300, 10); - let params = LinearParams::::ones(shape); - assert!(params.is_biased()); - let params = LinearParams::::zeros(shape); - assert!(!params.is_biased()); -} From 966d0a6d52433e8cef02da5467ecb34fa74d911d Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Mon, 13 May 2024 23:53:54 -0500 Subject: [PATCH 08/23] update Signed-off-by: Joe McCain III --- core/src/macros.rs | 41 +++++++++++++++++++ models/linear/src/impls/params/impl_from.rs | 3 +- models/linear/src/impls/params/impl_params.rs | 39 ++++-------------- models/linear/src/impls/params/impl_serde.rs | 6 +-- models/linear/src/model/linear.rs | 4 +- .../linear/src/params/{entry.rs => item.rs} | 4 +- models/linear/src/params/mod.rs | 8 ++-- .../linear/src/params/{params.rs => store.rs} | 16 ++++---- models/linear/tests/params.rs | 2 +- models/transformers/src/params/store.rs | 11 +++-- 10 files changed, 76 insertions(+), 58 deletions(-) rename models/linear/src/params/{entry.rs => item.rs} (96%) rename models/linear/src/params/{params.rs => store.rs} (96%) diff --git a/core/src/macros.rs b/core/src/macros.rs index 5af8ac49..539e7a26 100644 --- a/core/src/macros.rs +++ b/core/src/macros.rs @@ -198,3 +198,44 @@ macro_rules! getters { } }; } + +/// AS +#[macro_export] +macro_rules! dimensional { + + ($name:ident$(())?) => { + pub fn dim(&self) -> D::Pattern { + self.$name$(())?.dim() + } + + pub fn ndim(&self) -> usize { + self.$name$(())?.ndim() + } + + pub fn raw_dim(&self) -> D { + self.$name$(())?.dim() + } + /// Returns a reference to the current dimension, as a slice. + pub fn shape(&self) -> &[usize] { + self.$name$(())?.shape() + } + }; + (dim: $name:ident$(())?) => { + /// Returns a reference to the current dimension, as a slice. 
+ pub fn as_slice(&self) -> &[usize] { + self.$name$(())?.shape() + } + + pub fn into_pattern(self) -> D::Pattern { + self.$name$(())?.into_pattern() + } + + pub fn ndim(&self) -> usize { + self.$name$(())?.ndim() + } + + pub fn raw_dim(&self) -> D { + self.$name$(())?.dim().clone() + } + }; +} diff --git a/models/linear/src/impls/params/impl_from.rs b/models/linear/src/impls/params/impl_from.rs index 5564d08b..18be8c99 100644 --- a/models/linear/src/impls/params/impl_from.rs +++ b/models/linear/src/impls/params/impl_from.rs @@ -23,10 +23,9 @@ where fn into_iter(self) -> Self::IntoIter { let axis = Axis(0); - let bias = self.bias().unwrap(); self.weights() .axis_iter(axis) - .zip(bias.axis_iter(axis)) + .zip(self.bias().axis_iter(axis)) .map(|(w, b)| (w.to_owned(), b.to_owned())) .collect::>() .into_iter() diff --git a/models/linear/src/impls/params/impl_params.rs b/models/linear/src/impls/params/impl_params.rs index ad667f00..45919c9d 100644 --- a/models/linear/src/impls/params/impl_params.rs +++ b/models/linear/src/impls/params/impl_params.rs @@ -24,26 +24,6 @@ where } } -impl<'a, A, B, T, S, D, K> Predict for &'a ParamsBase -where - A: Dot, Output = B>, - B: Add<&'a ArrayBase, Output = B>, - D: RemoveAxis, - S: Data, - T: NdFloat, -{ - type Output = B; - - fn predict(&self, input: &A) -> Result { - let wt = self.weights().t().to_owned(); - let mut res = input.dot(&wt); - if let Some(bias) = self.bias() { - res = res + bias; - } - Ok(res) - } -} - impl Clone for ParamsBase where A: Clone, @@ -75,7 +55,7 @@ where S: Data, { fn eq(&self, other: &Self) -> bool { - self.weights == other.weights && self.bias == other.bias + self.weights() == other.weights && self.bias == other.bias } } @@ -87,7 +67,7 @@ where S: Data, { fn eq(&self, (weights, bias): &(ArrayBase, Option>)) -> bool { - self.weights() == weights && self.bias() == bias.as_ref() + self.weights() == weights && self.bias.as_ref() == bias.as_ref() } } @@ -98,13 +78,13 @@ where S: Data, { fn eq(&self, 
(weights, bias): &(ArrayBase, ArrayBase)) -> bool { - self.weights() == weights && self.bias() == Some(bias) + self.weights() == weights && self.bias.as_ref() == Some(bias) } } macro_rules! impl_predict { - ($( $($lt:lifetime)? $name:ident),* $(,)?) => { - $(impl_predict!(@impl $($lt)? $name);)* + ($($name:ident),* $(,)?) => { + $(impl_predict!(@impl $name);)* }; (@impl $name:ident) => { impl Predict for $name @@ -120,15 +100,14 @@ macro_rules! impl_predict { fn predict(&self, input: &A) -> Result { let wt = self.weights().t().to_owned(); let mut res = input.dot(&wt); - if let Some(bias) = self.bias() { + if let Some(bias) = self.bias.as_ref() { res = res + bias; } Ok(res) } } - }; - (@impl $name:ident<&'a $rhs:ident>) => { - impl<'a, A, B, T, S, D, K> Predict<&'a $rhs> for $name + + impl<'a, A, B, T, S, D, K> Predict for &'a $name where A: Dot, Output = B>, B: Add<&'a ArrayBase, Output = B>, @@ -141,7 +120,7 @@ macro_rules! impl_predict { fn predict(&self, input: &A) -> Result { let wt = self.weights().t().to_owned(); let mut res = input.dot(&wt); - if let Some(bias) = self.bias() { + if let Some(bias) = self.bias.as_ref() { res = res + bias; } Ok(res) diff --git a/models/linear/src/impls/params/impl_serde.rs b/models/linear/src/impls/params/impl_serde.rs index 2d8641fa..6188228b 100644 --- a/models/linear/src/impls/params/impl_serde.rs +++ b/models/linear/src/impls/params/impl_serde.rs @@ -4,7 +4,7 @@ */ #![cfg(feature = "serde")] -use crate::params::{Entry, ParamsBase}; +use crate::params::{Parameter, ParamsBase}; use core::marker::PhantomData; use nd::*; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -40,11 +40,11 @@ where where Ser: Serializer, { - (self.bias(), self.weights()).serialize(serializer) + (self.bias.as_ref(), self.weights()).serialize(serializer) } } -impl Serialize for Entry +impl Serialize for Parameter where A: Serialize, S: Data, diff --git a/models/linear/src/model/linear.rs b/models/linear/src/model/linear.rs index 
06b8798c..43144147 100644 --- a/models/linear/src/model/linear.rs +++ b/models/linear/src/model/linear.rs @@ -121,10 +121,10 @@ where D: RemoveAxis, { pub fn bias(&self) -> &Array { - self.params.bias().unwrap() + self.params().bias() } pub fn bias_mut(&mut self) -> &mut Array { - self.params.bias_mut().unwrap() + self.params_mut().bias_mut() } } diff --git a/models/linear/src/params/entry.rs b/models/linear/src/params/item.rs similarity index 96% rename from models/linear/src/params/entry.rs rename to models/linear/src/params/item.rs index 7e83bd45..9ddbf28b 100644 --- a/models/linear/src/params/entry.rs +++ b/models/linear/src/params/item.rs @@ -34,7 +34,7 @@ use strum::{AsRefStr, EnumDiscriminants, EnumIs, VariantNames}; ), strum(serialize_all = "lowercase") )] -pub enum Entry +pub enum Parameter where S: RawData, D: RemoveAxis, @@ -43,7 +43,7 @@ where Weight(ArrayBase), } -impl Entry +impl Parameter where D: RemoveAxis, S: RawData, diff --git a/models/linear/src/params/mod.rs b/models/linear/src/params/mod.rs index c6fcedc5..571bdecd 100644 --- a/models/linear/src/params/mod.rs +++ b/models/linear/src/params/mod.rs @@ -3,13 +3,13 @@ Contrib: FL03 */ #[doc(inline)] -pub use self::entry::{Entry, Param}; +pub use self::item::{Param, Parameter}; pub use self::mode::*; -pub use self::params::ParamsBase; +pub use self::store::*; -mod params; +mod store; -pub mod entry; +pub mod item; pub mod mode; #[doc(inline)] diff --git a/models/linear/src/params/params.rs b/models/linear/src/params/store.rs similarity index 96% rename from models/linear/src/params/params.rs rename to models/linear/src/params/store.rs index 5dfd1b74..b278d268 100644 --- a/models/linear/src/params/params.rs +++ b/models/linear/src/params/store.rs @@ -56,14 +56,6 @@ where } } - pub fn bias(&self) -> Option<&ArrayBase> { - self.bias.as_ref() - } - - pub fn bias_mut(&mut self) -> Option<&mut ArrayBase> { - self.bias.as_mut() - } - pub const fn weights(&self) -> &ArrayBase { &self.weights } @@ -139,6 
+131,14 @@ where _mode: PhantomData::, } } + + pub fn bias(&self) -> &ArrayBase { + self.bias.as_ref().unwrap() + } + + pub fn bias_mut(&mut self) -> &mut ArrayBase { + self.bias.as_mut().unwrap() + } } impl ParamsBase diff --git a/models/linear/tests/params.rs b/models/linear/tests/params.rs index aaa474a9..fed65fb7 100644 --- a/models/linear/tests/params.rs +++ b/models/linear/tests/params.rs @@ -30,7 +30,7 @@ fn test_builders() { let params = LinearParams::::ones(shape); assert!(params.is_biased()); assert_eq!(params.weights(), &Array2::ones(shape)); - assert_eq!(params.bias().unwrap(), &Array1::ones(D_MODEL)); + assert_eq!(params.bias(), &Array1::ones(D_MODEL)); let params = LinearParams::::zeros(shape); assert!(!params.is_biased()); assert_eq!(params.weights(), &Array2::zeros(shape)); diff --git a/models/transformers/src/params/store.rs b/models/transformers/src/params/store.rs index 1e316dab..4a3eda39 100644 --- a/models/transformers/src/params/store.rs +++ b/models/transformers/src/params/store.rs @@ -60,29 +60,28 @@ where { (self.q.view(), self.k.view(), self.v.view()) } - + /// Consumes the current parameters, returning a three-tuple the Q, K, and V matrices respectivley. pub fn into_qkv(self) -> (ArrayBase, ArrayBase, ArrayBase) where S: DataOwned, { (self.q, self.k, self.v) } - /// Return the [pattern](ndarray::Dimension::Pattern) of the dimension pub fn dim(&self) -> D::Pattern { - self.q.dim() + self.q().dim() } /// Get the rank of the parameters; i.e. the number of dimensions. pub fn rank(&self) -> usize { - self.q.ndim() + self.q().ndim() } /// Returns the raw dimension ([D](ndarray::Dimension)) of the parameters. pub fn raw_dim(&self) -> D { - self.q.raw_dim() + self.q().raw_dim() } /// Returns a slice of the current shape of the parameters. 
pub fn shape(&self) -> &[usize] { - self.q.shape() + self.q().shape() } ndview!(into_owned::(self) where A: Clone, S: Data); ndview!(to_owned::(&self) where A: Clone, S: Data); From bbbbb4cbff5eb66ad893b0a2908ed0686475718a Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Tue, 14 May 2024 00:06:24 -0500 Subject: [PATCH 09/23] update Signed-off-by: Joe McCain III --- models/linear/src/params/mode.rs | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/models/linear/src/params/mode.rs b/models/linear/src/params/mode.rs index 9d7e7367..a3c1e1bf 100644 --- a/models/linear/src/params/mode.rs +++ b/models/linear/src/params/mode.rs @@ -2,12 +2,15 @@ Appellation: mode Contrib: FL03 */ +use core::option::Option; -pub trait ParamMode: 'static { +pub trait Toggle: 'static {} + +pub trait ParamMode: Toggle { const BIASED: bool = false; fn is_biased(&self) -> bool { - Self::BIASED + core::any::type_name::() == core::any::type_name::() } private!(); @@ -39,6 +42,8 @@ macro_rules! mode { #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize,))] pub enum $T {} + impl Toggle for $T {} + impl ParamMode for $T { const BIASED: bool = $opt; @@ -52,7 +57,19 @@ macro_rules! mode { } + +macro_rules! impl_toggle { + ($($scope:ident$(<$T:ident>)?),* $(,)?) => { + $(impl_toggle!(@impl $scope$(<$T>)?);)* + }; + (@impl $scope:ident$(<$T:ident>)?) => { + impl$(<$T>)? Toggle for $scope$(<$T> where $T: 'static)? {} + }; +} + mode! 
{ Biased: true, Unbiased: false, } + +impl_toggle!(bool, char, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, Option); \ No newline at end of file From 0f00cc43292f78d44b7417c6c45dafa5b83bb1d7 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Tue, 14 May 2024 15:34:52 -0500 Subject: [PATCH 10/23] update Signed-off-by: Joe McCain III --- core/src/macros.rs | 54 +++++++++++++++++-------- core/src/traits/misc/toggle.rs | 30 ++++++++++++++ core/src/traits/mod.rs | 4 +- core/src/traits/ops.rs | 4 +- models/linear/src/params/mode.rs | 14 +------ models/transformers/src/params/store.rs | 20 ++------- 6 files changed, 78 insertions(+), 48 deletions(-) create mode 100644 core/src/traits/misc/toggle.rs diff --git a/core/src/macros.rs b/core/src/macros.rs index 539e7a26..2e76f65d 100644 --- a/core/src/macros.rs +++ b/core/src/macros.rs @@ -202,10 +202,15 @@ macro_rules! getters { /// AS #[macro_export] macro_rules! dimensional { + + (dim: $name:ident$(())?) => { + /// Returns a reference to the current dimension, as a slice. + pub fn as_slice(&self) -> &[usize] { + self.$name$(())?.shape() + } - ($name:ident$(())?) => { - pub fn dim(&self) -> D::Pattern { - self.$name$(())?.dim() + pub fn into_pattern(self) -> D::Pattern { + self.$name$(())?.into_pattern() } pub fn ndim(&self) -> usize { @@ -213,29 +218,46 @@ macro_rules! dimensional { } pub fn raw_dim(&self) -> D { - self.$name$(())?.dim() + self.$name$(())?.dim().clone() + } + }; + + + ($name:ident) => { + /// Return the [pattern](ndarray::Dimension::Pattern) of the dimension + pub fn dim(&self) -> D::Pattern { + self.$name.dim() + } + /// Returns rank (ndim) of the dimension + pub fn ndim(&self) -> usize { + self.$name.ndim() + } + /// Returns the raw dimension [D](ndarray::Dimension) + pub fn raw_dim(&self) -> D { + self.$name.dim() } /// Returns a reference to the current dimension, as a slice. 
pub fn shape(&self) -> &[usize] { - self.$name$(())?.shape() + self.$name.shape() } }; - (dim: $name:ident$(())?) => { - /// Returns a reference to the current dimension, as a slice. - pub fn as_slice(&self) -> &[usize] { - self.$name$(())?.shape() - } - pub fn into_pattern(self) -> D::Pattern { - self.$name$(())?.into_pattern() + ($name:ident()) => { + /// Return the [pattern](ndarray::Dimension::Pattern) of the dimension + pub fn dim(&self) -> D::Pattern { + self.$name().dim() } - + /// Returns rank (ndim) of the dimension pub fn ndim(&self) -> usize { - self.$name$(())?.ndim() + self.$name().ndim() } - + /// Returns the raw dimension [D](ndarray::Dimension) pub fn raw_dim(&self) -> D { - self.$name$(())?.dim().clone() + self.$name().raw_dim() + } + /// Returns a reference to the current dimension, as a slice. + pub fn shape(&self) -> &[usize] { + self.$name().shape() } }; } diff --git a/core/src/traits/misc/toggle.rs b/core/src/traits/misc/toggle.rs new file mode 100644 index 00000000..50adf421 --- /dev/null +++ b/core/src/traits/misc/toggle.rs @@ -0,0 +1,30 @@ +/* + Appellation: toggle + Contrib: FL03 +*/ + +pub trait Toggle: 'static {} + +pub trait Mode: Toggle { + + fn of() -> bool + where + K: Toggle, + { + core::any::TypeId::of::() == core::any::TypeId::of::() + } +} + +/* + ************* Implementations ************* +*/ +macro_rules! impl_toggle { + ($($scope:ident$(<$T:ident>)?),* $(,)?) => { + $(impl_toggle!(@impl $scope$(<$T>)?);)* + }; + (@impl $scope:ident$(<$T:ident>)?) => { + impl$(<$T>)? Toggle for $scope$(<$T> where $T: 'static)? 
{} + }; +} + +impl_toggle!(bool, char, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, Option); \ No newline at end of file diff --git a/core/src/traits/mod.rs b/core/src/traits/mod.rs index c8e1a33d..350e5cfa 100644 --- a/core/src/traits/mod.rs +++ b/core/src/traits/mod.rs @@ -26,18 +26,20 @@ pub mod arr { } } -pub(crate) mod misc { +pub mod misc { pub mod adjust; #[doc(hidden)] pub mod container; pub mod setup; pub mod store; + pub mod toggle; pub(crate) mod prelude { pub use super::adjust::*; pub use super::container::*; pub use super::setup::*; pub use super::store::*; + pub use super::toggle::*; } } diff --git a/core/src/traits/ops.rs b/core/src/traits/ops.rs index 68dce74b..48e1cf3a 100644 --- a/core/src/traits/ops.rs +++ b/core/src/traits/ops.rs @@ -15,7 +15,7 @@ pub trait Apply { F: FnMut(T) -> U; } -pub trait ApplyOnce { +pub trait ApplyOn { type Output; fn apply(self, f: F) -> Self::Output @@ -32,7 +32,7 @@ pub trait Transform { /* ************* Implementations ************* */ -impl ApplyOnce for S +impl ApplyOn for S where S: Iterator, { diff --git a/models/linear/src/params/mode.rs b/models/linear/src/params/mode.rs index a3c1e1bf..2770830b 100644 --- a/models/linear/src/params/mode.rs +++ b/models/linear/src/params/mode.rs @@ -2,10 +2,9 @@ Appellation: mode Contrib: FL03 */ +use concision::Toggle; use core::option::Option; -pub trait Toggle: 'static {} - pub trait ParamMode: Toggle { const BIASED: bool = false; @@ -57,19 +56,8 @@ macro_rules! mode { } - -macro_rules! impl_toggle { - ($($scope:ident$(<$T:ident>)?),* $(,)?) => { - $(impl_toggle!(@impl $scope$(<$T>)?);)* - }; - (@impl $scope:ident$(<$T:ident>)?) => { - impl$(<$T>)? Toggle for $scope$(<$T> where $T: 'static)? {} - }; -} - mode! 
{ Biased: true, Unbiased: false, } -impl_toggle!(bool, char, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, Option); \ No newline at end of file diff --git a/models/transformers/src/params/store.rs b/models/transformers/src/params/store.rs index 4a3eda39..06b0f22c 100644 --- a/models/transformers/src/params/store.rs +++ b/models/transformers/src/params/store.rs @@ -60,6 +60,7 @@ where { (self.q.view(), self.k.view(), self.v.view()) } + /// Consumes the current parameters, returning a three-tuple the Q, K, and V matrices respectivley. pub fn into_qkv(self) -> (ArrayBase, ArrayBase, ArrayBase) where @@ -67,22 +68,9 @@ where { (self.q, self.k, self.v) } - /// Return the [pattern](ndarray::Dimension::Pattern) of the dimension - pub fn dim(&self) -> D::Pattern { - self.q().dim() - } - /// Get the rank of the parameters; i.e. the number of dimensions. - pub fn rank(&self) -> usize { - self.q().ndim() - } - /// Returns the raw dimension ([D](ndarray::Dimension)) of the parameters. - pub fn raw_dim(&self) -> D { - self.q().raw_dim() - } - /// Returns a slice of the current shape of the parameters. - pub fn shape(&self) -> &[usize] { - self.q().shape() - } + + concision::dimensional!(q()); + ndview!(into_owned::(self) where A: Clone, S: Data); ndview!(to_owned::(&self) where A: Clone, S: Data); From 05547a38783978ddd5ee09461d11b9db3d631704 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Fri, 17 May 2024 08:27:22 -0500 Subject: [PATCH 11/23] update Signed-off-by: Joe McCain III --- models/transformers/src/macros.rs | 2 +- models/transformers/src/params/store.rs | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/models/transformers/src/macros.rs b/models/transformers/src/macros.rs index aabf8345..95523baf 100644 --- a/models/transformers/src/macros.rs +++ b/models/transformers/src/macros.rs @@ -12,7 +12,7 @@ macro_rules! 
ndbuilder { }; (@impl $method:ident::$call:ident() where $($rest:tt)*) => { pub fn $method>(shape: Sh) -> Self where $($rest)* { - Self::builder(shape, ArrayBase::$call) + Self::builder(shape, ndarray::ArrayBase::$call) } }; } diff --git a/models/transformers/src/params/store.rs b/models/transformers/src/params/store.rs index 06b0f22c..dcaa2e97 100644 --- a/models/transformers/src/params/store.rs +++ b/models/transformers/src/params/store.rs @@ -2,8 +2,8 @@ Appellation: params Contrib: FL03 */ +use concision::{dimensional, getters}; use nd::*; - use num::traits::{One, Zero}; pub struct ParamsBase, D = Ix2> @@ -34,12 +34,6 @@ where } } - ndbuilder!(new::default() where A: Default, S: DataOwned); - ndbuilder!(ones() where A: Clone + One, S: DataOwned); - ndbuilder!(zeros() where A: Clone + Zero, S: DataOwned); - - concision::getters!(q, k, v => ArrayBase); - pub fn from_elem(shape: Sh, value: A) -> Self where Sh: ShapeBuilder, @@ -69,7 +63,13 @@ where (self.q, self.k, self.v) } - concision::dimensional!(q()); + ndbuilder!(new::default() where A: Default, S: DataOwned); + ndbuilder!(ones() where A: Clone + One, S: DataOwned); + ndbuilder!(zeros() where A: Clone + Zero, S: DataOwned); + + getters!(q, k, v => ArrayBase); + + dimensional!(q()); ndview!(into_owned::(self) where A: Clone, S: Data); ndview!(to_owned::(&self) where A: Clone, S: Data); From 86fdb3f382807c241404536147e9050aec729ad6 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Fri, 17 May 2024 11:38:08 -0500 Subject: [PATCH 12/23] update Signed-off-by: Joe McCain III --- Cargo.toml | 3 +- core/src/func/activate.rs | 7 +- core/src/func/activate/binary.rs | 2 +- core/src/func/activate/nl.rs | 7 +- core/src/macros.rs | 35 ++-- core/src/math/traits.rs | 64 ++++---- core/src/nn/mod.rs | 4 + core/src/traits/misc/toggle.rs | 21 ++- models/linear/Cargo.toml | 17 +- models/linear/src/lib.rs | 2 + models/linear/src/macros.rs | 24 +++ models/linear/src/model/config.rs | 26 ++- 
models/linear/src/model/layout/layout.rs | 11 ++ models/linear/src/model/linear.rs | 49 ++++-- models/linear/src/norm/layer.rs | 89 ----------- models/linear/src/norm/layer/config.rs | 116 ++++++++++++++ models/linear/src/norm/layer/mod.rs | 51 ++++++ models/linear/src/norm/layer/model.rs | 169 ++++++++++++++++++++ models/linear/src/norm/mod.rs | 4 +- models/linear/src/params/mode.rs | 3 +- models/linear/src/params/store.rs | 53 +++--- models/linear/src/utils.rs | 4 +- models/linear/tests/{model.rs => linear.rs} | 0 models/linear/tests/norm.rs | 21 +++ models/transformers/src/params/store.rs | 6 +- 25 files changed, 601 insertions(+), 187 deletions(-) delete mode 100644 models/linear/src/norm/layer.rs create mode 100644 models/linear/src/norm/layer/config.rs create mode 100644 models/linear/src/norm/layer/mod.rs create mode 100644 models/linear/src/norm/layer/model.rs rename models/linear/tests/{model.rs => linear.rs} (100%) create mode 100644 models/linear/tests/norm.rs diff --git a/Cargo.toml b/Cargo.toml index 0f3ef677..90f493ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,9 +16,10 @@ version = "0.1.14" scsys = { default-features = false, branch = "v0.2.3", git = "https://github.com/scattered-systems/scsys.git", version = "0.2" } approx = "0.5" -itertools = "0.12" +itertools = "0.13" lazy_static = "1" ndarray = { default-features = false, version = "0.15" } +ndarray-stats = "0.5" num = { default-features = false, version = "0.4" } paste = "1" smart-default = "0.7" diff --git a/core/src/func/activate.rs b/core/src/func/activate.rs index 5ab8d035..fcedda7a 100644 --- a/core/src/func/activate.rs +++ b/core/src/func/activate.rs @@ -14,7 +14,7 @@ where x.clone() } -build_unary_trait!(LinearActivation.linear); +unary!(LinearActivation::linear(&self)); impl LinearActivation for T where @@ -33,6 +33,9 @@ pub(crate) mod prelude { pub use super::{linear, LinearActivation}; } -pub trait Activator { +#[doc(hidden)] +pub trait Evaluate { type Output; + + fn eval(&self, 
args: T) -> Self::Output; } diff --git a/core/src/func/activate/binary.rs b/core/src/func/activate/binary.rs index b45a0588..d02bb434 100644 --- a/core/src/func/activate/binary.rs +++ b/core/src/func/activate/binary.rs @@ -17,7 +17,7 @@ where } } -build_unary_trait!(Heavyside.heavyside,); +unary!(Heavyside::heavyside(&self),); macro_rules! impl_heavyside { ($($ty:ty),* $(,)*) => { diff --git a/core/src/func/activate/nl.rs b/core/src/func/activate/nl.rs index 93341e6f..da1cccce 100644 --- a/core/src/func/activate/nl.rs +++ b/core/src/func/activate/nl.rs @@ -54,7 +54,12 @@ where args.tanh() } -build_unary_trait!(ReLU.relu, Sigmoid.sigmoid, Softmax.softmax, Tanh.tanh,); +unary!( + ReLU::relu(&self), + Sigmoid::sigmoid(&self), + Softmax::softmax(&self), + Tanh::tanh(&self), +); /* ********** Implementations ********** diff --git a/core/src/macros.rs b/core/src/macros.rs index 2e76f65d..12e5e9f5 100644 --- a/core/src/macros.rs +++ b/core/src/macros.rs @@ -70,10 +70,10 @@ macro_rules! variant_constructor { } macro_rules! impl_unary { - ($name:ident.$call:ident<$T:ty>($f:expr) $($rest:tt)*) => { - impl_unary!(@impl $name.$call<$T>($f) $($rest)*); + ($name:ident::$call:ident<$T:ty>($f:expr) $($rest:tt)*) => { + impl_unary!(@impl $name::$call<$T>($f) $($rest)*); }; - (@impl $name:ident.$call:ident<$T:ty>($f:expr)) => { + (@impl $name:ident::$call:ident<$T:ty>($f:expr)) => { impl $name for $T { type Output = $T; @@ -84,13 +84,30 @@ macro_rules! impl_unary { }; } -macro_rules! build_unary_trait { - ($($name:ident.$call:ident),* $(,)?) => { +macro_rules! unary { + ($($name:ident::$call:ident),* $(,)?) => { $( - build_unary_trait!(@impl $name.$call); + unary!(@impl $name::$call(self)); )* }; - (@impl $name:ident.$call:ident) => { + ($($name:ident::$call:ident(self)),* $(,)?) => { + $( + unary!(@impl $name::$call(self)); + )* + }; + ($($name:ident::$call:ident(&self)),* $(,)?) 
=> { + $( + unary!(@impl $name::$call(&self)); + )* + }; + (@impl $name:ident::$call:ident(self)) => { + pub trait $name { + type Output; + + fn $call(self) -> Self::Output; + } + }; + (@impl $name:ident::$call:ident(&self)) => { pub trait $name { type Output; @@ -202,7 +219,7 @@ macro_rules! getters { /// AS #[macro_export] macro_rules! dimensional { - + (dim: $name:ident$(())?) => { /// Returns a reference to the current dimension, as a slice. pub fn as_slice(&self) -> &[usize] { @@ -222,7 +239,7 @@ macro_rules! dimensional { } }; - + ($name:ident) => { /// Return the [pattern](ndarray::Dimension::Pattern) of the dimension pub fn dim(&self) -> D::Pattern { diff --git a/core/src/math/traits.rs b/core/src/math/traits.rs index e817ecda..718b4275 100644 --- a/core/src/math/traits.rs +++ b/core/src/math/traits.rs @@ -5,52 +5,60 @@ use nd::{Array, ArrayBase, Data, Dimension}; use num::complex::{Complex, ComplexFloat}; -macro_rules! unary { - ($($name:ident::$method:ident),*) => { - $(unary!(@impl $name::$method);)* - }; - (@impl $name:ident::$method:ident) => { - pub trait $name { - type Output; - - fn $method(self) -> Self::Output; - } - }; - (@fn $($method:ident),* $(,)?) => { - $(fn $method(self) -> Self::Output;)* - }; -} - -unary!(Abs::abs, SquareRoot::sqrt); +unary!( + Abs::abs(self), + Cos::cos(self), + Cosh::cosh(self), + Sine::sin(self), + Sinh::sinh(self), + SquareRoot::sqrt(self) +); /* ********* Implementations ********* */ -macro_rules! fwd_unop { + +macro_rules! unary_impl { ($name:ident::$method:ident<[$($T:ty),* $(,)?]>) => { - fwd_unop!($name::$method.$method<[$($T: $T),*]>); + unary_impl!(@loop $name::$method<[$($T),*]>); + }; + ($($name:ident::$method:ident<$T:ty$(, Output = $O:ty)?>),* $(,)?) 
=> { + $(unary_impl!(@impl $name::$method<$T$(, Output = $O>)?);)* }; - ($name:ident::$method:ident.$call:ident<[$($T:ty: $O:ty),* $(,)?]>) => { - $(fwd_unop!(@impl $name::$method.$call<$T> -> $O);)* + ($($name:ident::$method:ident<$T:ty, Output = $O:ty>),* $(,)?) => { + $(unary_impl!(@impl $name::$method<$T, Output = $O>);)* }; - (@impl $name:ident::$method:ident$(.$call:ident)?<$T:ty>) => { - fwd_unop!(@impl $name::$method$(.$call)?<$T> -> $T); + (@loop $name:ident::$method:ident<[$($T:ty),* $(,)?]>) => { + $(unary_impl!(@impl $name::$method<$T>);)* }; - (@impl $name:ident::$method:ident<$T:ty> -> $O:ty) => { - fwd_unop!(@impl $name::$method.$method<$T> -> $O); + (@impl $name:ident::$method:ident<$T:ty>) => { + unary_impl!(@impl $name::$method<$T, Output = $T>); }; - (@impl $name:ident::$method:ident.$call:ident<$T:ty> -> $O:ty) => { + (@impl $name:ident::$method:ident<$T:ty, Output = $O:ty>) => { impl $name for $T { type Output = $O; fn $method(self) -> Self::Output { - <$T>::$call(self) + <$T>::$method(self) } } }; } -fwd_unop!(SquareRoot::sqrt<[f32, f64]>); +macro_rules! unary_impls { + ($($name:ident::$method:ident<[$($T:ty),* $(,)?]>),* $(,)?) 
=> { + $(unary_impl!(@loop $name::$method<[$($T),*]>);)* + }; +} + +unary_impls!( + Abs::abs<[f32, f64]>, + Cosh::cosh<[f32, f64]>, + Cos::cos<[f32, f64]>, + Sinh::sinh<[f32, f64]>, + Sine::sin<[f32, f64]>, + SquareRoot::sqrt<[f32, f64]> +); impl SquareRoot for Complex where diff --git a/core/src/nn/mod.rs b/core/src/nn/mod.rs index b5eb0d61..c0eb1f81 100644 --- a/core/src/nn/mod.rs +++ b/core/src/nn/mod.rs @@ -12,5 +12,9 @@ pub(crate) mod prelude { pub use super::model::prelude::*; } +#[cfg(any(feature = "alloc", feature = "std"))] +pub type ForwardDyn, O = T> = + crate::rust::Box>; + #[cfg(test)] mod tests {} diff --git a/core/src/traits/misc/toggle.rs b/core/src/traits/misc/toggle.rs index 50adf421..0bca252e 100644 --- a/core/src/traits/misc/toggle.rs +++ b/core/src/traits/misc/toggle.rs @@ -6,7 +6,6 @@ pub trait Toggle: 'static {} pub trait Mode: Toggle { - fn of() -> bool where K: Toggle, @@ -16,7 +15,7 @@ pub trait Mode: Toggle { } /* - ************* Implementations ************* + ************* Implementations ************* */ macro_rules! impl_toggle { ($($scope:ident$(<$T:ident>)?),* $(,)?) => { @@ -27,4 +26,20 @@ macro_rules! 
impl_toggle { }; } -impl_toggle!(bool, char, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, Option); \ No newline at end of file +impl_toggle!( + bool, + char, + i8, + i16, + i32, + i64, + i128, + isize, + u8, + u16, + u32, + u64, + u128, + usize, + Option +); diff --git a/models/linear/Cargo.toml b/models/linear/Cargo.toml index ee925322..2122d7eb 100644 --- a/models/linear/Cargo.toml +++ b/models/linear/Cargo.toml @@ -87,7 +87,11 @@ doctest = true test = true [[test]] -name = "model" +name = "linear" +required-features = ["std"] + +[[test]] +name = "norm" required-features = ["std"] [[test]] @@ -106,6 +110,11 @@ strum.workspace = true optional = true version = "0.5" +[dependencies.concision-core] +default-features = false +path = "../../core" +version = "0.1.14" + [dependencies.serde] default-features = false features = ["derive"] @@ -116,12 +125,6 @@ version = "1" optional = true version = "0.1" - -[dependencies.concision-core] -default-features = false -path = "../../core" -version = "0.1.14" - [dev-dependencies] lazy_static.workspace = true diff --git a/models/linear/src/lib.rs b/models/linear/src/lib.rs index 2b011f81..6f457ab7 100644 --- a/models/linear/src/lib.rs +++ b/models/linear/src/lib.rs @@ -14,8 +14,10 @@ extern crate alloc; extern crate concision_core as concision; extern crate ndarray as nd; +// extern crate ndarray_stats as ndstats; pub use self::model::{Config, Features, Layout, Linear}; +pub use self::norm::LayerNorm; pub use self::params::{mode::*, ParamsBase}; #[allow(unused_imports)] pub use self::{primitives::*, traits::*, utils::*}; diff --git a/models/linear/src/macros.rs b/models/linear/src/macros.rs index ab2ee6eb..6fad31d9 100644 --- a/models/linear/src/macros.rs +++ b/models/linear/src/macros.rs @@ -25,6 +25,30 @@ macro_rules! impl_param_builder { }; } +macro_rules! impl_model_builder { + ($method:ident$(.$call:ident)? where $($rest:tt)*) => { + impl_model_builder!(@impl $method$(.$call)? 
where $($rest)*); + }; + (@impl $method:ident where $($rest:tt)*) => { + impl_model_builder!(@impl $method.$method where $($rest)*); + }; + (@impl $method:ident.$call:ident where $($rest:tt)*) => { + pub fn $method(shape: Sh) -> Self + where + K: $crate::params::mode::ParamMode, + Sh: ndarray::ShapeBuilder, + $($rest)* + { + let config = $crate::model::Config::::new().with_shape(shape); + let params = $crate::params::ParamsBase::$call(config.dim()); + $crate::model::Linear { + config, + params, + } + } + }; +} + macro_rules! ndview { ($method:ident::$($rest:tt)*) => { ndview!(@impl $method.$method::$($rest)*); diff --git a/models/linear/src/model/config.rs b/models/linear/src/model/config.rs index 1c558441..ad6f9347 100644 --- a/models/linear/src/model/config.rs +++ b/models/linear/src/model/config.rs @@ -79,14 +79,19 @@ where _biased: self._biased, } } - /// Returns a cloned reference to the [dimension](ndarray::Dimension) of the [layout](Layout) - pub fn dim(&self) -> D { - self.layout().dim() - } - pub fn into_pattern(self) -> D::Pattern { - self.dim().into_pattern() + pub fn with_shape(self, shape: Sh) -> Config + where + E: RemoveAxis, + Sh: ShapeBuilder, + { + Config { + layout: self.layout.with_shape(shape), + name: self.name, + _biased: self._biased, + } } + /// This function attempts to convert the [layout](Layout) of the [Config] into a new [dimension](ndarray::Dimension) pub fn into_dimensionality(self, dim: E) -> Result, nd::ShapeError> where @@ -122,6 +127,15 @@ where &self.name } + /// Returns a cloned reference to the [dimension](ndarray::Dimension) of the [layout](Layout) + pub fn dim(&self) -> D { + self.layout().dim() + } + + pub fn into_pattern(self) -> D::Pattern { + self.dim().into_pattern() + } + pub fn ndim(&self) -> usize { self.layout().ndim() } diff --git a/models/linear/src/model/layout/layout.rs b/models/linear/src/model/layout/layout.rs index 47a1ef8f..0cf5df2f 100644 --- a/models/linear/src/model/layout/layout.rs +++ 
b/models/linear/src/model/layout/layout.rs @@ -41,6 +41,17 @@ where Self { dim, features } } + pub fn with_shape(self, shape: Sh) -> Layout + where + E: RemoveAxis, + Sh: ShapeBuilder, + { + let shape = shape.into_shape(); + let dim = shape.raw_dim().clone(); + let features = Features::from_shape(dim.clone()); + Layout { dim, features } + } + pub fn as_slice(&self) -> &[usize] { self.dim.slice() } diff --git a/models/linear/src/model/linear.rs b/models/linear/src/model/linear.rs index 43144147..55984521 100644 --- a/models/linear/src/model/linear.rs +++ b/models/linear/src/model/linear.rs @@ -5,7 +5,8 @@ use super::{Config, Layout}; use crate::{Biased, LinearParams, ParamMode, Unbiased}; use concision::prelude::{Predict, Result}; -use nd::{Array, Dimension, Ix2, RemoveAxis}; +use nd::prelude::*; +use nd::RemoveAxis; /// An implementation of a linear model. /// @@ -24,6 +25,10 @@ impl Linear where D: RemoveAxis, { + impl_model_builder!(default where A: Default); + impl_model_builder!(ones where A: Clone + num::One); + impl_model_builder!(zeros where A: Clone + num::Zero); + pub fn from_config(config: Config) -> Self where A: Clone + Default, @@ -43,14 +48,7 @@ where Self { config, params } } - pub fn with_params(self, params: LinearParams) -> Linear - where - E: RemoveAxis, - { - let config = self.config.into_dimensionality(params.raw_dim()).unwrap(); - Linear { config, params } - } - + /// Applies an activcation function onto the prediction of the model. 
pub fn activate(&self, args: &X, func: F) -> Result where F: for<'a> Fn(&'a Y) -> Y, @@ -108,6 +106,14 @@ where self.config().is_biased() } + pub fn with_params(self, params: LinearParams) -> Linear + where + E: RemoveAxis, + { + let config = self.config.into_dimensionality(params.raw_dim()).unwrap(); + Linear { config, params } + } + pub fn with_name(self, name: impl ToString) -> Self { Self { config: self.config.with_name(name), @@ -120,6 +126,16 @@ impl Linear where D: RemoveAxis, { + pub fn biased(shape: Sh) -> Self + where + A: Default, + Sh: ShapeBuilder, + { + let config = Config::::new().with_shape(shape); + let params = LinearParams::biased(config.dim()); + Linear { config, params } + } + pub fn bias(&self) -> &Array { self.params().bias() } @@ -128,3 +144,18 @@ where self.params_mut().bias_mut() } } + +impl Linear +where + D: RemoveAxis, +{ + pub fn unbiased(shape: Sh) -> Self + where + A: Default, + Sh: ShapeBuilder, + { + let config = Config::::new().with_shape(shape); + let params = LinearParams::unbiased(config.dim()); + Linear { config, params } + } +} diff --git a/models/linear/src/norm/layer.rs b/models/linear/src/norm/layer.rs deleted file mode 100644 index 4f250a87..00000000 --- a/models/linear/src/norm/layer.rs +++ /dev/null @@ -1,89 +0,0 @@ -/* - Appellation: layer - Contrib: FL03 -*/ -use super::EPSILON; -use crate::{LinearParams, ParamMode}; -use nd::prelude::*; -use nd::RemoveAxis; - -// #62 -/// [LayerNorm] adhears to the [Layer Normalization](https://arxiv.org/abs/1607.06450) algorithm. 
-/// -/// ### Resources -pub struct LayerNorm -where - D: Dimension, -{ - config: LayerNormConfig, - params: LinearParams, -} - -pub struct LayerNormConfig { - pub dim: D, - pub eps: f64, -} - -impl LayerNormConfig -where - D: Dimension, -{ - pub fn new() -> Self { - Self { - dim: D::default(), - eps: 1e-5, - } - } - - pub fn create(dim: D, eps: f64) -> Self - where - D: Default, - { - Self { dim, eps } - } - - pub fn with_dim(dim: D) -> Self { - Self { dim, eps: 1e-5 } - } -} - -impl Default for LayerNormConfig -where - D: Default, -{ - fn default() -> Self { - Self { - dim: D::default(), - eps: EPSILON, - } - } -} - -impl LayerNorm -where - D: RemoveAxis, - K: ParamMode, -{ - pub fn from_shape(shape: Sh) -> Self - where - A: Default, - Sh: ShapeBuilder, - { - let dim = shape.into_shape().raw_dim().clone(); - let config = LayerNormConfig::with_dim(dim.clone()); - let params = LinearParams::::default(dim); - Self { config, params } - } - - pub fn config(&self) -> &LayerNormConfig { - &self.config - } - - pub fn params(&self) -> &LinearParams { - &self.params - } - - pub fn params_mut(&mut self) -> &mut LinearParams { - &mut self.params - } -} diff --git a/models/linear/src/norm/layer/config.rs b/models/linear/src/norm/layer/config.rs new file mode 100644 index 00000000..ef9e7469 --- /dev/null +++ b/models/linear/src/norm/layer/config.rs @@ -0,0 +1,116 @@ +/* + Appellation: config + Contrib: FL03 +*/ +use super::EPSILON; +use nd::prelude::{Axis, Dimension, Ix2}; + +pub struct Config { + pub axis: Option, + pub dim: D, + pub eps: f64, +} + +impl Config +where + D: Dimension, +{ + pub fn new() -> ConfigBuilder { + ConfigBuilder::new() + } + + pub fn axis(&self) -> Option<&Axis> { + self.axis.as_ref() + } + + pub fn axis_mut(&mut self) -> &mut Option { + &mut self.axis + } + + pub fn eps(&self) -> f64 { + self.eps + } + + pub fn eps_mut(&mut self) -> &mut f64 { + &mut self.eps + } + + pub fn dim(&self) -> D::Pattern { + self.raw_dim().into_pattern() + } + + pub fn 
dim_mut(&mut self) -> &mut D { + &mut self.dim + } + + pub fn ndim(&self) -> usize { + self.dim.ndim() + } + + pub fn raw_dim(&self) -> D { + self.dim.clone() + } + + pub fn shape(&self) -> &[usize] { + self.dim.slice() + } + + pub fn shape_mut(&mut self) -> &mut [usize] { + self.dim.slice_mut() + } +} + +impl Default for Config +where + D: Default, +{ + fn default() -> Self { + Self { + axis: None, + dim: D::default(), + eps: EPSILON, + } + } +} + +pub struct ConfigBuilder { + axis: Option, + dim: D, + eps: f64, +} + +impl ConfigBuilder +where + D: Dimension, +{ + pub fn new() -> Self { + Self { + axis: None, + dim: D::default(), + eps: 1e-5, + } + } + + pub fn axis(mut self, axis: Axis) -> Self { + self.axis = Some(axis); + self + } + + pub fn dim(mut self, dim: D) -> Self { + self.dim = dim; + self + } + + pub fn eps(mut self, eps: f64) -> Self { + self.eps = eps; + self + } + + pub fn build(self) -> Config { + Config { + axis: self.axis, + dim: self.dim, + eps: self.eps, + } + } +} diff --git a/models/linear/src/norm/layer/mod.rs b/models/linear/src/norm/layer/mod.rs new file mode 100644 index 00000000..6b54d6e8 --- /dev/null +++ b/models/linear/src/norm/layer/mod.rs @@ -0,0 +1,51 @@ +/* + Appellation: layer + Contrib: FL03 +*/ +//! # Layer Normalization +//! +//! This module provides the necessary tools for creating and training layer normalization layers. 
+pub(crate) use self::utils::*; +pub use self::{config::*, model::*}; + +pub(crate) mod config; +pub(crate) mod model; + +pub const EPSILON: f64 = 1e-5; + +pub(crate) mod prelude { + pub use super::config::Config as LayerNormConfig; + pub use super::model::LayerNorm; +} + +pub(crate) mod utils { + use nd::{Array, Axis, Dimension, RemoveAxis}; + use num::traits::{Float, FromPrimitive}; + + pub(crate) fn layer_norm(x: &Array, eps: f64) -> Array + where + A: Float + FromPrimitive, + D: Dimension, + { + let mean = x.mean().unwrap(); + let denom = { + let eps = A::from(eps).unwrap(); + let var = x.var(A::zero()); + (var + eps).sqrt() + }; + x.mapv(|xi| (xi - mean) / denom) + } + + pub(crate) fn layer_norm_axis(x: &Array, axis: Axis, eps: f64) -> Array + where + A: Float + FromPrimitive, + D: RemoveAxis, + { + let eps = A::from(eps).unwrap(); + let mean = x.mean_axis(axis).unwrap(); + let var = x.var_axis(axis, A::zero()); + let inv_std = var.mapv(|v| (v + eps).recip().sqrt()); + let x_norm = (x - &mean) * &inv_std; + x_norm + } +} diff --git a/models/linear/src/norm/layer/model.rs b/models/linear/src/norm/layer/model.rs new file mode 100644 index 00000000..979802b6 --- /dev/null +++ b/models/linear/src/norm/layer/model.rs @@ -0,0 +1,169 @@ +/* + Appellation: layer + Contrib: FL03 +*/ +use super::Config; +use crate::{Biased, LinearParams, ParamMode, Unbiased}; +use concision::Forward; +use nd::prelude::*; +use nd::RemoveAxis; +use num::traits::{Float, FromPrimitive, One, Zero}; + +// #62 +/// +/// Layer Normalization directly estimates the normalization statistics from the summed inputs +/// to the neurons within a _hidden_ layer, eliminating the need to introduce any additional dependencies. +/// +/// [LayerNorm] follows the [Layer Normalization](https://arxiv.org/abs/1607.06450) paper. 
+/// +/// ### Resources +pub struct LayerNorm +where + D: Dimension, +{ + config: Config, + params: LinearParams, +} + +impl LayerNorm +where + D: RemoveAxis, + K: ParamMode, +{ + pub fn from_config(config: Config) -> Self + where + A: Default, + { + let params = LinearParams::::default(config.dim()); + Self { config, params } + } + + pub fn default(shape: Sh) -> Self + where + A: Default, + Sh: ShapeBuilder, + { + let dim = shape.into_shape().raw_dim().clone(); + let config = Config::new().dim(dim.clone()).build(); + let params = LinearParams::::default(dim); + Self { config, params } + } + + pub fn ones(shape: Sh) -> Self + where + A: Clone + One, + Sh: ShapeBuilder, + { + let dim = shape.into_shape().raw_dim().clone(); + let config = Config::new().dim(dim.clone()).build(); + let params = LinearParams::::ones(dim); + Self { config, params } + } + + pub fn zeros(shape: Sh) -> Self + where + A: Clone + Zero, + Sh: ShapeBuilder, + { + let dim = shape.into_shape().raw_dim().clone(); + let config = Config::new().dim(dim.clone()).build(); + let params = LinearParams::::zeros(dim); + Self { config, params } + } + + pub const fn config(&self) -> &Config { + &self.config + } + + pub fn is_biased(&self) -> bool { + self.params().is_biased() + } + /// Returns an immutable reference to the layer's parameters. + pub const fn params(&self) -> &LinearParams { + &self.params + } + /// Returns a mutable reference to the layer's parameters. 
+ pub fn params_mut(&mut self) -> &mut LinearParams { + &mut self.params + } + + pub fn dim(&self) -> D::Pattern { + self.config().dim() + } + + pub fn eps(&self) -> f64 { + self.config().eps() + } + + pub fn ndim(&self) -> usize { + self.config().ndim() + } + + pub fn raw_dim(&self) -> D { + self.config().raw_dim() + } + + pub fn shape(&self) -> &[usize] { + self.config().shape() + } +} + +impl Default for LayerNorm +where + A: Default, + D: RemoveAxis, +{ + fn default() -> Self { + Self { + config: Config::default(), + params: Default::default(), + } + } +} + +impl Default for LayerNorm +where + A: Default, + D: RemoveAxis, +{ + fn default() -> Self { + Self { + config: Config::default(), + params: Default::default(), + } + } +} + +impl Forward> for LayerNorm +where + A: Float + FromPrimitive, + D: RemoveAxis, +{ + type Output = Array; + + fn forward(&self, x: &Array) -> Self::Output { + let norm = if let Some(axis) = self.config().axis() { + super::layer_norm_axis(x, *axis, self.eps()) + } else { + super::layer_norm(x, self.eps()) + }; + norm * self.params().weights() + self.params().bias() + } +} + +impl Forward> for LayerNorm +where + A: Float + FromPrimitive, + D: RemoveAxis, +{ + type Output = Array; + + fn forward(&self, x: &Array) -> Self::Output { + let norm = if let Some(axis) = self.config().axis() { + super::layer_norm_axis(x, *axis, self.eps()) + } else { + super::layer_norm(x, self.eps()) + }; + norm * self.params().weights() + } +} diff --git a/models/linear/src/norm/mod.rs b/models/linear/src/norm/mod.rs index 67fbca27..d8d607dd 100644 --- a/models/linear/src/norm/mod.rs +++ b/models/linear/src/norm/mod.rs @@ -9,8 +9,6 @@ pub use self::layer::LayerNorm; pub mod layer; -pub const EPSILON: f64 = 1e-5; - pub(crate) mod prelude { - pub use super::layer::LayerNorm; + pub use super::layer::prelude::*; } diff --git a/models/linear/src/params/mode.rs b/models/linear/src/params/mode.rs index 2770830b..ff153868 100644 --- a/models/linear/src/params/mode.rs 
+++ b/models/linear/src/params/mode.rs @@ -9,7 +9,7 @@ pub trait ParamMode: Toggle { const BIASED: bool = false; fn is_biased(&self) -> bool { - core::any::type_name::() == core::any::type_name::() + core::any::TypeId::of::() == core::any::TypeId::of::() } private!(); @@ -60,4 +60,3 @@ mode! { Biased: true, Unbiased: false, } - diff --git a/models/linear/src/params/store.rs b/models/linear/src/params/store.rs index b278d268..586b227f 100644 --- a/models/linear/src/params/store.rs +++ b/models/linear/src/params/store.rs @@ -7,6 +7,10 @@ use core::marker::PhantomData; use nd::*; use num::{One, Zero}; +/// The base paramter store for a linear model. +/// +/// [ParamsBase] works to store the weights and biases of a linear model. +/// The structure is parameterized over the type and dimension of the data as well as the current mode of the store. /// pub struct ParamsBase, D = Ix2, K = Unbiased> where @@ -71,34 +75,24 @@ where pub fn in_features(&self) -> usize { self.features().dmodel() } + + pub fn out_features(&self) -> usize { + if self.ndim() == 1 { + return 1; + } + self.shape()[1] + } /// Returns true if the parameter store is biased; /// Compares the [TypeId](core::any::TypeId) of the store with the [Biased](crate::Biased) type. pub fn is_biased(&self) -> bool where K: 'static, { - use core::any::TypeId; - TypeId::of::() == TypeId::of::() + crate::is_biased::() } - pub fn ndim(&self) -> usize { - self.weights().ndim() - } + concision::dimensional!(weights()); - pub fn out_features(&self) -> usize { - if self.ndim() == 1 { - return 1; - } - self.shape()[1] - } - /// Returns the raw dimension of the weights. - pub fn raw_dim(&self) -> D { - self.weights().raw_dim() - } - /// Returns the shape of the weights. 
- pub fn shape(&self) -> &[usize] { - self.weights().shape() - } ndview!(into_owned::(self) where A: Clone, S: Data); ndview!(into_shared::(self) where A: Clone, S: DataOwned); @@ -131,11 +125,11 @@ where _mode: PhantomData::, } } - + /// Return an unwraped, immutable reference to the bias array. pub fn bias(&self) -> &ArrayBase { self.bias.as_ref().unwrap() } - + /// Return an unwraped, mutable reference to the bias array. pub fn bias_mut(&mut self) -> &mut ArrayBase { self.bias.as_mut().unwrap() } @@ -190,6 +184,21 @@ where } } +impl Default for ParamsBase +where + A: Default, + D: Dimension, + S: DataOwned, +{ + fn default() -> Self { + Self { + bias: Some(Default::default()), + weights: Default::default(), + _mode: PhantomData::, + } + } +} + impl Default for ParamsBase where A: Default, @@ -200,7 +209,7 @@ where Self { bias: None, weights: Default::default(), - _mode: PhantomData, + _mode: PhantomData::, } } } diff --git a/models/linear/src/utils.rs b/models/linear/src/utils.rs index 445e0862..ca6142dc 100644 --- a/models/linear/src/utils.rs +++ b/models/linear/src/utils.rs @@ -3,7 +3,6 @@ Contrib: FL03 */ use crate::params::Biased; -use core::any::TypeId; use nd::{ArrayBase, Axis, Dimension, RawData, RemoveAxis}; /// A utilitarian funciton for building bias tensors. @@ -33,6 +32,9 @@ where } } +/// A utilitarian function for checking if a type is [Biased]; returns false otherwise. +/// Compares the [TypeId](core::any::TypeId) of `K` to the [TypeId](core::any::TypeId) of [Biased]. 
pub fn is_biased() -> bool { + use core::any::TypeId; TypeId::of::() == TypeId::of::() } diff --git a/models/linear/tests/model.rs b/models/linear/tests/linear.rs similarity index 100% rename from models/linear/tests/model.rs rename to models/linear/tests/linear.rs diff --git a/models/linear/tests/norm.rs b/models/linear/tests/norm.rs new file mode 100644 index 00000000..7ed0dfff --- /dev/null +++ b/models/linear/tests/norm.rs @@ -0,0 +1,21 @@ +/* + Appellation: norm + Contrib: FL03 +*/ +extern crate concision_core as concision; +extern crate concision_linear as linear; + +use concision::{linarr, Forward}; +use linear::{Biased, LayerNorm}; +use ndarray::prelude::*; + +#[test] +fn test_layer_norm() { + let shape = (3, 3); + let x = linarr::(shape).unwrap() + 1f64; + + let ln = LayerNorm::::ones(shape); + let y = ln.forward(&x); + + assert_eq!(y.dim(), shape); +} diff --git a/models/transformers/src/params/store.rs b/models/transformers/src/params/store.rs index dcaa2e97..fcda0e9e 100644 --- a/models/transformers/src/params/store.rs +++ b/models/transformers/src/params/store.rs @@ -54,7 +54,7 @@ where { (self.q.view(), self.k.view(), self.v.view()) } - + /// Consumes the current parameters, returning a three-tuple the Q, K, and V matrices respectivley. 
pub fn into_qkv(self) -> (ArrayBase, ArrayBase, ArrayBase) where @@ -62,13 +62,13 @@ where { (self.q, self.k, self.v) } - + ndbuilder!(new::default() where A: Default, S: DataOwned); ndbuilder!(ones() where A: Clone + One, S: DataOwned); ndbuilder!(zeros() where A: Clone + Zero, S: DataOwned); getters!(q, k, v => ArrayBase); - + dimensional!(q()); ndview!(into_owned::(self) where A: Clone, S: Data); From 86cf2ad1f4150c4d62c05a2ea41503f78bf3fdd4 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Fri, 17 May 2024 15:08:45 -0500 Subject: [PATCH 13/23] update Signed-off-by: Joe McCain III --- core/src/func/dropout.rs | 98 +++++++++++++++++++ core/src/func/mod.rs | 3 + core/src/macros.rs | 34 ++----- core/tests/func.rs | 16 +++ models/linear/Cargo.toml | 2 +- models/linear/src/impls/impl_rand.rs | 6 +- models/linear/src/impls/params/impl_from.rs | 8 +- models/linear/src/impls/params/impl_params.rs | 4 +- models/linear/src/impls/params/impl_serde.rs | 2 +- models/linear/src/macros.rs | 29 +++--- models/linear/src/params/store.rs | 62 ++++++++---- models/linear/src/traits.rs | 11 +++ models/linear/tests/norm.rs | 18 +++- models/transformers/src/codec/model.rs | 13 +-- 14 files changed, 227 insertions(+), 79 deletions(-) create mode 100644 core/src/func/dropout.rs create mode 100644 core/tests/func.rs diff --git a/core/src/func/dropout.rs b/core/src/func/dropout.rs new file mode 100644 index 00000000..969f57f5 --- /dev/null +++ b/core/src/func/dropout.rs @@ -0,0 +1,98 @@ +/* + Appellation: dropout + Contrib: FL03 +*/ +#![cfg(feature = "rand")] +use crate::Forward; +use nd::prelude::*; +use nd::{DataOwned, RemoveAxis, ScalarOperand}; +use ndrand::rand_distr::Bernoulli; +use ndrand::RandomExt; +use num::traits::Num; + +pub fn dropout(array: &ArrayBase, p: f64) -> Array +where + A: Num + ScalarOperand, + D: Dimension, + S: DataOwned, +{ + // Create a Bernoulli distribution for dropout + let distribution = Bernoulli::new(p).unwrap(); + + // Create a mask of the same shape as 
the input array + let mask: Array = Array::random(array.dim(), distribution); + let mask = mask.mapv(|x| if x { A::zero() } else { A::one() }); + + // Element-wise multiplication to apply dropout + array * mask +} + +pub fn dropout_axis(array: &ArrayBase, _axis: Axis, p: f64) -> Array +where + A: Num + ScalarOperand, + D: RemoveAxis, + S: DataOwned, +{ + // Create a Bernoulli distribution for dropout + let distribution = Bernoulli::new(p).unwrap(); + + // Create a mask of the same shape as the input array + let _mask: Array = Array::random(array.dim(), distribution); + + unimplemented!() +} + +pub struct Dropout { + axis: Option, + p: f64, +} + +impl Dropout { + pub fn new(p: f64) -> Self { + Self { axis: None, p } + } + + pub fn with_axis(self, axis: Axis) -> Self { + Self { + axis: Some(axis), + ..self + } + } + + pub fn dropout(&self, array: &ArrayBase) -> Array + where + A: Num + ScalarOperand, + D: Dimension, + S: DataOwned, + { + dropout(array, self.p) + } + + pub fn dropout_axis(&self, array: &ArrayBase) -> Array + where + A: Num + ScalarOperand, + D: RemoveAxis, + S: DataOwned, + { + dropout_axis(array, self.axis.unwrap(), self.p) + } +} + +impl Default for Dropout { + fn default() -> Self { + Self::new(0.5) + } +} + +impl Forward> for Dropout +where + A: Num + ScalarOperand, + D: Dimension, + S: DataOwned, +{ + type Output = Array; + + fn forward(&self, input: &ArrayBase) -> Self::Output { + dropout(input, self.p) + } +} diff --git a/core/src/func/mod.rs b/core/src/func/mod.rs index 266ce461..545415bd 100644 --- a/core/src/func/mod.rs +++ b/core/src/func/mod.rs @@ -6,9 +6,12 @@ pub use self::prelude::*; pub mod activate; +pub mod dropout; pub mod loss; pub(crate) mod prelude { pub use super::activate::prelude::*; + #[cfg(feature = "rand")] + pub use super::dropout::*; pub use super::loss::prelude::*; } diff --git a/core/src/macros.rs b/core/src/macros.rs index 12e5e9f5..33501dfd 100644 --- a/core/src/macros.rs +++ b/core/src/macros.rs @@ -118,36 +118,22 @@ 
macro_rules! unary { #[macro_export] macro_rules! builder { - ($(#[derive($($d:ident),*)])?$name:ident::<$inner:ty> {$($k:ident: $v:ty),* $(,)?}) => { - $crate::builder!(@loop $(#[derive($($d),*)])?$name::<$inner> {$($k: $v),*}); + ($(#[derive($($d:ident),+)])?$name:ident::<$inner:ty> {$($k:ident: $v:ty),* $(,)?}) => { + $crate::builder!(@loop builder: $name, derive: [$($($d),+)?], inner: $inner {$($k: $v),*}); }; - (@loop #[derive($($d:ident),*)] $name:ident::<$inner:ty> {$($k:ident: $v:ty),* $(,)?}) => { - pub struct $name { - inner: $inner, - } - - impl $name { - pub fn new() -> Self { - Self { inner: Default::default() } - } - - pub fn build(self) -> $inner { - self.inner - } - - $( - pub fn $k(mut self, $k: $v) -> Self { - self.inner.$k = $k; - self - } - )* - } + ($(#[derive($($d:ident),+)])? $name:ident($inner:ty) {$($k:ident: $v:ty),* $(,)?}) => { + $crate::builder!(@loop builder: $name, derive: [$($($d),+)?], inner: $inner {$($k: $v),*}); }; - (@loop $name:ident::<$inner:ty> {$($k:ident: $v:ty),* $(,)?}) => { + (@loop builder: $name:ident, derive: [$($d:ident),* $(,)?], inner: $inner:ty {$($k:ident: $v:ty),* $(,)?}) => { + + #[derive(Default, $($d),*)] pub struct $name { inner: $inner, } + $crate::builder!(@impl builder: $name, inner: $inner {$($k: $v),*}); + }; + (@impl builder: $name:ident, inner: $inner:ty {$($k:ident: $v:ty),* $(,)?}) => { impl $name { pub fn new() -> Self { Self { diff --git a/core/tests/func.rs b/core/tests/func.rs new file mode 100644 index 00000000..2f888a90 --- /dev/null +++ b/core/tests/func.rs @@ -0,0 +1,16 @@ +#![allow(unused_imports)] +extern crate concision_core as concision; + +use concision::func::Dropout; +use concision::Forward; +use ndarray::prelude::*; + +#[test] +#[cfg(feature = "rand")] +fn test_dropout() { + let arr = Array2::::ones((2, 2)); + assert!(arr.iter().all(|&x| x == 1.0)); + let dropout = Dropout::new(0.5); + let res = dropout.forward(&arr); + assert!(res.iter().any(|&x| x == 0.0)); +} diff --git 
a/models/linear/Cargo.toml b/models/linear/Cargo.toml index 2122d7eb..ab7770d1 100644 --- a/models/linear/Cargo.toml +++ b/models/linear/Cargo.toml @@ -92,7 +92,7 @@ required-features = ["std"] [[test]] name = "norm" -required-features = ["std"] +required-features = ["approx", "std"] [[test]] name = "params" diff --git a/models/linear/src/impls/impl_rand.rs b/models/linear/src/impls/impl_rand.rs index e1a6fc6b..d571afd3 100644 --- a/models/linear/src/impls/impl_rand.rs +++ b/models/linear/src/impls/impl_rand.rs @@ -59,13 +59,13 @@ where let b_dim = bias_dim(self.raw_dim()); Self { bias: Some(Array::uniform_between(b_dim, low, high)), - weights: Array::uniform_between(self.raw_dim(), low, high), + weight: Array::uniform_between(self.raw_dim(), low, high), _mode: self._mode, } } else if !self.is_biased() && self.bias.is_some() { Self { bias: None, - weights: Array::uniform_between(self.raw_dim(), low, high), + weight: Array::uniform_between(self.raw_dim(), low, high), _mode: self._mode, } } else { @@ -74,7 +74,7 @@ where .bias .as_ref() .map(|b| Array::uniform_between(b.raw_dim(), low, high)), - weights: Array::uniform_between(self.raw_dim(), low, high), + weight: Array::uniform_between(self.raw_dim(), low, high), _mode: self._mode, } } diff --git a/models/linear/src/impls/params/impl_from.rs b/models/linear/src/impls/params/impl_from.rs index 18be8c99..ef22d72a 100644 --- a/models/linear/src/impls/params/impl_from.rs +++ b/models/linear/src/impls/params/impl_from.rs @@ -90,7 +90,7 @@ where let bias = ArrayBase::from_elem((), bias); Self { bias: Some(bias), - weights, + weight: weights, _mode: PhantomData, } } @@ -102,7 +102,7 @@ where fn from((weights, bias): (Array1, Option)) -> Self { Self { bias: bias.map(|b| ArrayBase::from_elem((), b)), - weights, + weight: weights, _mode: PhantomData, } } @@ -116,7 +116,7 @@ where fn from((weights, bias): NodeBase) -> Self { Self { bias, - weights, + weight: weights, _mode: PhantomData::, } } @@ -130,7 +130,7 @@ where fn 
from((weights, bias): Pair, ArrayBase>) -> Self { Self { bias: Some(bias), - weights, + weight: weights, _mode: PhantomData::, } } diff --git a/models/linear/src/impls/params/impl_params.rs b/models/linear/src/impls/params/impl_params.rs index 45919c9d..f3808a22 100644 --- a/models/linear/src/impls/params/impl_params.rs +++ b/models/linear/src/impls/params/impl_params.rs @@ -32,7 +32,7 @@ where { fn clone(&self) -> Self { Self { - weights: self.weights.clone(), + weight: self.weight.clone(), bias: self.bias.clone(), _mode: self._mode, } @@ -55,7 +55,7 @@ where S: Data, { fn eq(&self, other: &Self) -> bool { - self.weights() == other.weights && self.bias == other.bias + self.weights() == other.weight && self.bias == other.bias } } diff --git a/models/linear/src/impls/params/impl_serde.rs b/models/linear/src/impls/params/impl_serde.rs index 6188228b..407c0c69 100644 --- a/models/linear/src/impls/params/impl_serde.rs +++ b/models/linear/src/impls/params/impl_serde.rs @@ -23,7 +23,7 @@ where let (bias, weights) = Deserialize::deserialize(deserializer)?; Ok(Self { bias, - weights, + weight: weights, _mode: PhantomData, }) } diff --git a/models/linear/src/macros.rs b/models/linear/src/macros.rs index 6fad31d9..99ec867a 100644 --- a/models/linear/src/macros.rs +++ b/models/linear/src/macros.rs @@ -3,23 +3,25 @@ Contrib: FL03 */ -macro_rules! impl_param_builder { - ($call:ident where $($rest:tt)*) => { - impl_param_builder!(@impl $call where $($rest)*); +macro_rules! impl_params_builder { + ($method:ident$(.$call:ident)? where $($rest:tt)*) => { + impl_params_builder!(@impl $method$(.$call)? 
where $($rest)*); + }; + (@impl $method:ident where $($rest:tt)*) => { + impl_params_builder!(@impl $method.$method where $($rest)*); }; - (@impl $call:ident where $($rest:tt)*) => { - pub fn $call(shape: Sh) -> Self + (@impl $method:ident.$call:ident where $($rest:tt)*) => { + pub fn $method(shape: Sh) -> Self where K: $crate::params::mode::ParamMode, Sh: ndarray::ShapeBuilder, $($rest)* { - let shape = shape.into_shape(); - let dim = shape.raw_dim().clone(); + let dim = shape.into_shape().raw_dim().clone(); ParamsBase { bias: build_bias(K::BIASED, dim.clone(), |dim| ndarray::ArrayBase::$call(dim)), - weights: ndarray::ArrayBase::$call(dim), - _mode: core::marker::PhantomData, + weight: ndarray::ArrayBase::$call(dim), + _mode: ::core::marker::PhantomData::, } } }; @@ -39,11 +41,10 @@ macro_rules! impl_model_builder { Sh: ndarray::ShapeBuilder, $($rest)* { - let config = $crate::model::Config::::new().with_shape(shape); - let params = $crate::params::ParamsBase::$call(config.dim()); + let dim = shape.into_shape().raw_dim().clone(); $crate::model::Linear { - config, - params, + config: $crate::model::Config::::new().with_shape(dim.clone()), + params: $crate::params::ParamsBase::$call(dim), } } }; @@ -107,7 +108,7 @@ macro_rules! ndview { (@apply $call:ident($self:expr)$(.$as:ident())?) => { $crate::params::ParamsBase { bias: $self.bias$(.$as())?.map(|arr| arr.$call()), - weights: $self.weights.$call(), + weight: $self.weight.$call(), _mode: $self._mode, } }; diff --git a/models/linear/src/params/store.rs b/models/linear/src/params/store.rs index 586b227f..2c9aea47 100644 --- a/models/linear/src/params/store.rs +++ b/models/linear/src/params/store.rs @@ -2,23 +2,22 @@ Appellation: params Contrib: FL03 */ -use crate::{build_bias, Biased, Features, Node, Unbiased}; +use crate::{build_bias, Biased, Features, Node, ParamMode, Unbiased}; use core::marker::PhantomData; use nd::*; use num::{One, Zero}; -/// The base paramter store for a linear model. 
-/// -/// [ParamsBase] works to store the weights and biases of a linear model. -/// The structure is parameterized over the type and dimension of the data as well as the current mode of the store. -/// -pub struct ParamsBase, D = Ix2, K = Unbiased> +/// The [ParamsBase] struct is a generic store for linear parameters. The store mimics +/// the underlying [ArrayBase](ndarray::ArrayBase), enabling developers to specify +/// the data repr and dimension. Additionally, the store is parameterized to +/// accept a `K` type, used to designate the store as either [Biased](crate::Biased) or [Unbiased](crate::Unbiased). +pub struct ParamsBase, D = Ix2, K = Biased> where D: Dimension, S: RawData, { pub(crate) bias: Option>, - pub(crate) weights: ArrayBase, + pub(crate) weight: ArrayBase, pub(crate) _mode: PhantomData, } @@ -27,9 +26,32 @@ where D: RemoveAxis, S: RawData, { - impl_param_builder!(default where A: Default, S: DataOwned); - impl_param_builder!(ones where A: Clone + One, S: DataOwned); - impl_param_builder!(zeros where A: Clone + Zero, S: DataOwned); + pub fn from_elem(shape: Sh, elem: A) -> Self + where + A: Clone, + K: ParamMode, + S: DataOwned, + Sh: ShapeBuilder, + { + let dim = shape.into_shape().raw_dim().clone(); + let bias = if K::BIASED { + Some(ArrayBase::from_elem( + crate::bias_dim(dim.clone()), + elem.clone(), + )) + } else { + None + }; + Self { + bias, + weight: ArrayBase::from_elem(dim, elem), + _mode: PhantomData::, + } + } + + impl_params_builder!(default where A: Default, S: DataOwned); + impl_params_builder!(ones where A: Clone + One, S: DataOwned); + impl_params_builder!(zeros where A: Clone + Zero, S: DataOwned); pub fn into_biased(self) -> ParamsBase where @@ -40,14 +62,14 @@ where if self.is_biased() { return ParamsBase { bias: self.bias, - weights: self.weights, + weight: self.weight, _mode: PhantomData::, }; } let sm = crate::bias_dim(self.raw_dim()); ParamsBase { bias: Some(ArrayBase::default(sm)), - weights: self.weights, + weight: 
self.weight, _mode: PhantomData::, } } @@ -55,17 +77,17 @@ where pub fn into_unbiased(self) -> ParamsBase { ParamsBase { bias: None, - weights: self.weights, + weight: self.weight, _mode: PhantomData::, } } pub const fn weights(&self) -> &ArrayBase { - &self.weights + &self.weight } pub fn weights_mut(&mut self) -> &mut ArrayBase { - &mut self.weights + &mut self.weight } pub fn features(&self) -> Features { @@ -121,7 +143,7 @@ where let dim = shape.into_shape().raw_dim().clone(); Self { bias: build_bias(true, dim.clone(), ArrayBase::default), - weights: ArrayBase::default(dim), + weight: ArrayBase::default(dim), _mode: PhantomData::, } } @@ -149,7 +171,7 @@ where { Self { bias: None, - weights: ArrayBase::default(shape), + weight: ArrayBase::default(shape), _mode: PhantomData::, } } @@ -193,7 +215,7 @@ where fn default() -> Self { Self { bias: Some(Default::default()), - weights: Default::default(), + weight: Default::default(), _mode: PhantomData::, } } @@ -208,7 +230,7 @@ where fn default() -> Self { Self { bias: None, - weights: Default::default(), + weight: Default::default(), _mode: PhantomData::, } } diff --git a/models/linear/src/traits.rs b/models/linear/src/traits.rs index a1b7aef1..5d609a02 100644 --- a/models/linear/src/traits.rs +++ b/models/linear/src/traits.rs @@ -2,6 +2,17 @@ Appellation: traits Contrib: FL03 */ +use crate::Biased; + pub trait IsBiased { fn is_biased(&self) -> bool; } + +impl IsBiased for T +where + T: 'static, +{ + fn is_biased(&self) -> bool { + core::any::TypeId::of::() == core::any::TypeId::of::() + } +} diff --git a/models/linear/tests/norm.rs b/models/linear/tests/norm.rs index 7ed0dfff..119b67d4 100644 --- a/models/linear/tests/norm.rs +++ b/models/linear/tests/norm.rs @@ -7,15 +7,29 @@ extern crate concision_linear as linear; use concision::{linarr, Forward}; use linear::{Biased, LayerNorm}; + +use approx::assert_abs_diff_eq; +use lazy_static::lazy_static; use ndarray::prelude::*; +const SHAPE: (usize, usize) = (3, 3); + 
+lazy_static! { + static ref NORM: Array2 = array![ + [-0.5492, -0.1619, 0.2254], + [0.6127, 1.0000, 1.3873], + [1.7746, 2.1619, 2.5492], + ]; +} + #[test] fn test_layer_norm() { - let shape = (3, 3); - let x = linarr::(shape).unwrap() + 1f64; + let shape = SHAPE; + let x = linarr::(shape).unwrap(); let ln = LayerNorm::::ones(shape); let y = ln.forward(&x); assert_eq!(y.dim(), shape); + assert_abs_diff_eq!(y, *NORM, epsilon = 1e-4); } diff --git a/models/transformers/src/codec/model.rs b/models/transformers/src/codec/model.rs index 3fb02b6a..494c0a0e 100644 --- a/models/transformers/src/codec/model.rs +++ b/models/transformers/src/codec/model.rs @@ -24,14 +24,11 @@ impl Codec { ); } -builder!( - #[derive(Default)] - CodecBuilder:: { - ctx: Context, - decoder: Decoder, - encoder: Encoder, - } -); +builder!(CodecBuilder:: { + ctx: Context, + decoder: Decoder, + encoder: Encoder, +}); #[derive(Default)] pub struct Generator { From 6ba75d45ad57dca6638b7e73ef67769979bb3868 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Fri, 17 May 2024 15:44:41 -0500 Subject: [PATCH 14/23] update Signed-off-by: Joe McCain III --- core/src/func/activate/nl.rs | 7 +++--- models/transformers/src/attention/head.rs | 28 +++++++++++++++++++++-- models/transformers/src/attention/mod.rs | 28 +++++++++++++++++++++++ 3 files changed, 58 insertions(+), 5 deletions(-) diff --git a/core/src/func/activate/nl.rs b/core/src/func/activate/nl.rs index da1cccce..fb7fefd7 100644 --- a/core/src/func/activate/nl.rs +++ b/core/src/func/activate/nl.rs @@ -4,7 +4,7 @@ */ use ndarray::*; use num::complex::{Complex, ComplexFloat}; -use num::{Float, Zero}; +use num::traits::Zero; pub fn relu(args: &T) -> T where @@ -23,10 +23,11 @@ where (T::one() + (*args).neg().exp()).recip() } -pub fn softmax(args: &Array) -> Array +pub fn softmax(args: &ArrayBase) -> Array where + A: ComplexFloat, D: Dimension, - T: Float, + S: Data, { let denom = args.mapv(|x| x.exp()).sum(); args.mapv(|x| x.exp() / denom) diff --git 
a/models/transformers/src/attention/head.rs b/models/transformers/src/attention/head.rs index 1b201a8f..b7050249 100644 --- a/models/transformers/src/attention/head.rs +++ b/models/transformers/src/attention/head.rs @@ -4,7 +4,10 @@ */ use crate::params::ParamsBase; use concision::getters; +use nd::linalg::Dot; use nd::*; +use num::complex::ComplexFloat; +use num::traits::FromPrimitive; pub struct AttentionHead, D = Ix2> where @@ -39,8 +42,16 @@ where { Self::from_params(ParamsBase::from_elem(shape, value)) } - /// Returns a reference to the underlying parameters. - pub fn params(&self) -> &ParamsBase { + #[allow(dead_code)] + pub(crate) fn dk(&self) -> A + where + A: FromPrimitive, + { + A::from_usize(self.k().len_of(Axis(1))).unwrap() + } + + /// Returns an immuable reference to the underlying parameters. + pub const fn params(&self) -> &ParamsBase { &self.params } /// Returns a mutable reference to the underlying parameters. @@ -53,3 +64,16 @@ where ndbuilder!(ones() where A: Clone + num::One, S: DataOwned); ndbuilder!(zeros() where A: Clone + num::Zero, S: DataOwned); } + +impl AttentionHead, D> +where + D: Dimension, +{ + pub fn attention(&self) -> Array + where + A: ComplexFloat + ScalarOperand, + Array: Dot, Output = Array>, + { + crate::attention::scaled_dot_product(self.q(), self.k(), self.v()) + } +} diff --git a/models/transformers/src/attention/mod.rs b/models/transformers/src/attention/mod.rs index f449f39e..618dc862 100644 --- a/models/transformers/src/attention/mod.rs +++ b/models/transformers/src/attention/mod.rs @@ -3,6 +3,7 @@ Contrib: FL03 */ pub use self::head::AttentionHead; +pub use self::utils::*; pub(crate) mod head; @@ -10,4 +11,31 @@ pub mod multi; pub(crate) mod prelude { pub use super::head::AttentionHead; + pub use super::utils::*; +} + +pub(crate) mod utils { + use concision::func::activate::softmax; + use nd::linalg::Dot; + use nd::{Array, Axis, Dimension, ScalarOperand}; + use num::complex::ComplexFloat; + + pub fn 
scaled_dot_product( + q: &Array, + k: &Array, + v: &Array, + ) -> Array + where + A: ComplexFloat + ScalarOperand, + D: Dimension, + Array: Dot, Output = Array>, + { + let qk = q.dot(&k.t().to_owned()); + let scale = { + let dk = A::from(k.len_of(Axis(1))).unwrap(); + dk.sqrt() + }; + let scaled = qk * scale.recip(); + softmax(&scaled).dot(&v) + } } From 857e7d82c64d42bea73dc2bc287e864d5e77aaa4 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sat, 18 May 2024 06:46:06 -0500 Subject: [PATCH 15/23] update Signed-off-by: Joe McCain III --- core/src/func/activate/nl.rs | 13 ++++++++++++ core/src/func/dropout.rs | 24 ++++++++--------------- core/src/math/traits.rs | 10 ++++++---- models/transformers/src/attention/head.rs | 11 ++++++++++- models/transformers/src/attention/mod.rs | 16 +++++++++------ models/transformers/src/params/store.rs | 11 ++++++----- 6 files changed, 53 insertions(+), 32 deletions(-) diff --git a/core/src/func/activate/nl.rs b/core/src/func/activate/nl.rs index fb7fefd7..d451577a 100644 --- a/core/src/func/activate/nl.rs +++ b/core/src/func/activate/nl.rs @@ -143,3 +143,16 @@ nonlinear!( Complex < f64 > ]>, ); + +impl Softmax for ArrayBase +where + A: ComplexFloat, + D: Dimension, + S: Data, +{ + type Output = Array; + + fn softmax(&self) -> Self::Output { + softmax(self) + } +} diff --git a/core/src/func/dropout.rs b/core/src/func/dropout.rs index 969f57f5..3eb901a9 100644 --- a/core/src/func/dropout.rs +++ b/core/src/func/dropout.rs @@ -42,21 +42,18 @@ where unimplemented!() } +/// +/// +/// ### Parameters +/// +/// - (p) Probability of dropping an element pub struct Dropout { - axis: Option, p: f64, } impl Dropout { pub fn new(p: f64) -> Self { - Self { axis: None, p } - } - - pub fn with_axis(self, axis: Axis) -> Self { - Self { - axis: Some(axis), - ..self - } + Self { p } } pub fn dropout(&self, array: &ArrayBase) -> Array @@ -68,13 +65,8 @@ impl Dropout { dropout(array, self.p) } - pub fn dropout_axis(&self, array: &ArrayBase) -> Array - 
where - A: Num + ScalarOperand, - D: RemoveAxis, - S: DataOwned, - { - dropout_axis(array, self.axis.unwrap(), self.p) + pub fn scale(&self) -> f64 { + (1f64 - self.p).recip() } } diff --git a/core/src/math/traits.rs b/core/src/math/traits.rs index 718b4275..d5af4af3 100644 --- a/core/src/math/traits.rs +++ b/core/src/math/traits.rs @@ -9,6 +9,7 @@ unary!( Abs::abs(self), Cos::cos(self), Cosh::cosh(self), + Exp::exp(self), Sine::sin(self), Sinh::sinh(self), SquareRoot::sqrt(self) @@ -53,10 +54,11 @@ macro_rules! unary_impls { unary_impls!( Abs::abs<[f32, f64]>, - Cosh::cosh<[f32, f64]>, - Cos::cos<[f32, f64]>, - Sinh::sinh<[f32, f64]>, - Sine::sin<[f32, f64]>, + Cosh::cosh<[f32, f64, Complex, Complex]>, + Cos::cos<[f32, f64, Complex, Complex]>, + Exp::exp<[f32, f64, Complex, Complex]>, + Sinh::sinh<[f32, f64, Complex, Complex]>, + Sine::sin<[f32, f64, Complex, Complex]>, SquareRoot::sqrt<[f32, f64]> ); diff --git a/models/transformers/src/attention/head.rs b/models/transformers/src/attention/head.rs index b7050249..faab2c8e 100644 --- a/models/transformers/src/attention/head.rs +++ b/models/transformers/src/attention/head.rs @@ -59,6 +59,14 @@ where &mut self.params } + pub fn qkv(&self) -> (&ArrayBase, &ArrayBase, &ArrayBase) { + self.params().qkv() + } + + pub fn into_qkv(self) -> (ArrayBase, ArrayBase, ArrayBase) { + self.params.into_qkv() + } + getters!(params::<[q, k, v]> => ArrayBase); ndbuilder!(new::default() where A: Default, S: DataOwned); ndbuilder!(ones() where A: Clone + num::One, S: DataOwned); @@ -74,6 +82,7 @@ where A: ComplexFloat + ScalarOperand, Array: Dot, Output = Array>, { - crate::attention::scaled_dot_product(self.q(), self.k(), self.v()) + let (q, k, v) = self.qkv(); + crate::attention::scaled_dot_product(q, k, v) } } diff --git a/models/transformers/src/attention/mod.rs b/models/transformers/src/attention/mod.rs index 618dc862..0fb65d4f 100644 --- a/models/transformers/src/attention/mod.rs +++ b/models/transformers/src/attention/mod.rs @@ 
-15,11 +15,18 @@ pub(crate) mod prelude { } pub(crate) mod utils { - use concision::func::activate::softmax; + use concision::func::activate::Softmax; use nd::linalg::Dot; use nd::{Array, Axis, Dimension, ScalarOperand}; use num::complex::ComplexFloat; + pub(crate) fn scale_dk(dk: A) -> A + where + A: ComplexFloat + ScalarOperand, + { + dk.sqrt().recip() + } + pub fn scaled_dot_product( q: &Array, k: &Array, @@ -31,11 +38,8 @@ pub(crate) mod utils { Array: Dot, Output = Array>, { let qk = q.dot(&k.t().to_owned()); - let scale = { - let dk = A::from(k.len_of(Axis(1))).unwrap(); - dk.sqrt() - }; + let scale = scale_dk(A::from(k.len_of(Axis(1))).unwrap()); let scaled = qk * scale.recip(); - softmax(&scaled).dot(&v) + scaled.softmax().dot(&v) } } diff --git a/models/transformers/src/params/store.rs b/models/transformers/src/params/store.rs index fcda0e9e..2b107531 100644 --- a/models/transformers/src/params/store.rs +++ b/models/transformers/src/params/store.rs @@ -55,14 +55,15 @@ where (self.q.view(), self.k.view(), self.v.view()) } - /// Consumes the current parameters, returning a three-tuple the Q, K, and V matrices respectivley. - pub fn into_qkv(self) -> (ArrayBase, ArrayBase, ArrayBase) - where - S: DataOwned, - { + /// Consumes the store and returns a three-tuple consisting of the query, key, and value arrays respectively. 
+ pub fn into_qkv(self) -> (ArrayBase, ArrayBase, ArrayBase) { (self.q, self.k, self.v) } + pub fn qkv(&self) -> (&ArrayBase, &ArrayBase, &ArrayBase) { + (&self.q, &self.k, &self.v) + } + ndbuilder!(new::default() where A: Default, S: DataOwned); ndbuilder!(ones() where A: Clone + One, S: DataOwned); ndbuilder!(zeros() where A: Clone + Zero, S: DataOwned); From f74da48796b476af0cc67857cb279f9088fd15a1 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sat, 18 May 2024 07:42:45 -0500 Subject: [PATCH 16/23] update Signed-off-by: Joe McCain III --- .github/ISSUE_TEMPLATE/proposal.md | 9 +- core/src/func/activate.rs | 15 ++- core/src/func/activate/binary.rs | 31 ++++-- core/src/func/activate/nl.rs | 82 ++++++++------- core/src/func/dropout.rs | 4 +- core/src/lib.rs | 2 +- core/src/math/traits.rs | 26 +++++ core/src/params/impls/impl_rand.rs | 4 +- core/src/rand/generate.rs | 99 ------------------- core/src/rand/initialize.rs | 48 +++++++-- core/src/rand/mod.rs | 2 - core/src/traits/predict.rs | 18 ++-- models/linear/src/impls/params/impl_params.rs | 75 +++++++------- models/linear/src/model/linear.rs | 4 +- 14 files changed, 194 insertions(+), 225 deletions(-) delete mode 100644 core/src/rand/generate.rs diff --git a/.github/ISSUE_TEMPLATE/proposal.md b/.github/ISSUE_TEMPLATE/proposal.md index 1d5e2cb0..d7bacdf8 100644 --- a/.github/ISSUE_TEMPLATE/proposal.md +++ b/.github/ISSUE_TEMPLATE/proposal.md @@ -1,7 +1,7 @@ --- -name: Proposal -about: A proposal for a new feature or change -title: 'Proposal:' +name: Improvement Proposal +about: A formal proposal discussing any new features, changes, or improvements to the project. 
+title: 'CNC-0000:' labels: ['proposal'] projects: ['@FL03/concision:features', '@FL03/concision:roadmap'] assignees: @@ -10,3 +10,6 @@ assignees: --- +### Resources + +- [Google](https://google.com) \ No newline at end of file diff --git a/core/src/func/activate.rs b/core/src/func/activate.rs index fcedda7a..492ed4d2 100644 --- a/core/src/func/activate.rs +++ b/core/src/func/activate.rs @@ -7,23 +7,20 @@ pub use self::{binary::*, nl::*}; pub mod binary; pub mod nl; -pub fn linear(x: &T) -> T -where - T: Clone, -{ - x.clone() +pub fn linear(x: T) -> T { + x } -unary!(LinearActivation::linear(&self)); +unary!(LinearActivation::linear(self)); -impl LinearActivation for T +impl<'a, T> LinearActivation for &'a T where T: Clone, { type Output = T; - fn linear(&self) -> Self::Output { - linear(self) + fn linear(self) -> Self::Output { + self.clone() } } diff --git a/core/src/func/activate/binary.rs b/core/src/func/activate/binary.rs index d02bb434..5b417dc0 100644 --- a/core/src/func/activate/binary.rs +++ b/core/src/func/activate/binary.rs @@ -6,18 +6,18 @@ use nd::{Array, ArrayBase, Data, Dimension}; use num::{One, Zero}; /// -pub fn heavyside(x: &T) -> T +pub fn heavyside(x: T) -> T where T: One + PartialOrd + Zero, { - if x > &T::zero() { + if x > T::zero() { T::one() } else { T::zero() } } -unary!(Heavyside::heavyside(&self),); +unary!(Heavyside::heavyside(self),); macro_rules! impl_heavyside { ($($ty:ty),* $(,)*) => { @@ -27,7 +27,7 @@ macro_rules! impl_heavyside { impl Heavyside for $ty { type Output = $ty; - fn heavyside(&self) -> Self::Output { + fn heavyside(self) -> Self::Output { heavyside(self) } } @@ -36,15 +36,28 @@ macro_rules! 
impl_heavyside { impl_heavyside!(f32, f64, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize,); -impl Heavyside for ArrayBase +impl Heavyside for ArrayBase where - A: Heavyside, + A: Clone + Heavyside, D: Dimension, S: Data, { - type Output = Array<::Output, D>; + type Output = Array; - fn heavyside(&self) -> Self::Output { - self.map(Heavyside::heavyside) + fn heavyside(self) -> Self::Output { + self.mapv(Heavyside::heavyside) + } +} + +impl<'a, A, B, S, D> Heavyside for &'a ArrayBase +where + A: Clone + Heavyside, + D: Dimension, + S: Data, +{ + type Output = Array; + + fn heavyside(self) -> Self::Output { + self.mapv(Heavyside::heavyside) } } diff --git a/core/src/func/activate/nl.rs b/core/src/func/activate/nl.rs index d451577a..fa1d364b 100644 --- a/core/src/func/activate/nl.rs +++ b/core/src/func/activate/nl.rs @@ -2,53 +2,38 @@ Appellation: sigmoid Contrib: FL03 */ +use crate::math::Exp; use ndarray::*; use num::complex::{Complex, ComplexFloat}; use num::traits::Zero; -pub fn relu(args: &T) -> T +pub fn relu(args: T) -> T where - T: Clone + PartialOrd + Zero, + T: PartialOrd + Zero, { - if args > &T::zero() { - return args.clone(); + if args > T::zero() { + return args; } T::zero() } -pub fn sigmoid(args: &T) -> T +pub fn sigmoid(args: T) -> T where T: ComplexFloat, { - (T::one() + (*args).neg().exp()).recip() + (T::one() + args.neg().exp()).recip() } pub fn softmax(args: &ArrayBase) -> Array where - A: ComplexFloat, + A: ComplexFloat + ScalarOperand, D: Dimension, S: Data, { - let denom = args.mapv(|x| x.exp()).sum(); - args.mapv(|x| x.exp() / denom) + args.exp() / args.exp().sum() } -pub fn softmax_axis(args: &Array, axis: Option) -> Array -where - D: Dimension + RemoveAxis, - T: NdFloat, -{ - let exp = args.mapv(|x| x.exp()); - if let Some(axis) = axis { - let denom = exp.sum_axis(Axis(axis)); - exp / denom - } else { - let denom = exp.sum(); - exp / denom - } -} - -pub fn tanh(args: &T) -> T +pub fn tanh(args: T) -> T where T: ComplexFloat, 
{ @@ -56,10 +41,10 @@ where } unary!( - ReLU::relu(&self), - Sigmoid::sigmoid(&self), - Softmax::softmax(&self), - Tanh::tanh(&self), + ReLU::relu(self), + Sigmoid::sigmoid(self), + Softmax::softmax(self), + Tanh::tanh(self), ); /* @@ -82,7 +67,7 @@ macro_rules! nonlinear { impl $rho for $T { type Output = $T; - fn $call(&self) -> Self::Output { + fn $call(self) -> Self::Output { $call(self) } } @@ -90,7 +75,7 @@ macro_rules! nonlinear { impl<'a> $rho for &'a $T { type Output = $T; - fn $call(&self) -> Self::Output { + fn $call(self) -> Self::Output { $call(*self) } } @@ -105,12 +90,24 @@ macro_rules! nonlinear { { type Output = Array<::Output, D>; - fn $call(&self) -> Self::Output { - self.map($name::$call) + fn $call(self) -> Self::Output { + self.mapv($name::$call) } } - }; + impl<'a, A, S, D> $name for &'a ArrayBase + where + A: Clone + $name, + D: Dimension, + S: Data + { + type Output = Array<::Output, D>; + + fn $call(self) -> Self::Output { + self.mapv($name::$call) + } + } + }; } nonlinear!( @@ -146,13 +143,26 @@ nonlinear!( impl Softmax for ArrayBase where - A: ComplexFloat, + A: ComplexFloat + ScalarOperand, + D: Dimension, + S: Data, +{ + type Output = Array; + + fn softmax(self) -> Self::Output { + softmax(&self) + } +} + +impl<'a, A, S, D> Softmax for &'a ArrayBase +where + A: ComplexFloat + ScalarOperand, D: Dimension, S: Data, { type Output = Array; - fn softmax(&self) -> Self::Output { + fn softmax(self) -> Self::Output { softmax(self) } } diff --git a/core/src/func/dropout.rs b/core/src/func/dropout.rs index 3eb901a9..00b24b13 100644 --- a/core/src/func/dropout.rs +++ b/core/src/func/dropout.rs @@ -42,9 +42,11 @@ where unimplemented!() } +/// The [Dropout] layer is randomly zeroizes inputs with a given probability (`p`). +/// This regularization technique is often used to prevent overfitting. 
/// /// -/// ### Parameters +/// ### Config /// /// - (p) Probability of dropping an element pub struct Dropout { diff --git a/core/src/lib.rs b/core/src/lib.rs index b45a9a8f..b6c0acb9 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -16,7 +16,7 @@ pub use self::nn::Module; pub use self::{primitives::*, traits::prelude::*, types::prelude::*, utils::prelude::*}; #[cfg(feature = "rand")] -pub use self::rand::{GenerateRandom, RandomExt}; +pub use self::rand::{Initialize, InitializeExt}; #[macro_use] pub(crate) mod macros; diff --git a/core/src/math/traits.rs b/core/src/math/traits.rs index d5af4af3..6d1de631 100644 --- a/core/src/math/traits.rs +++ b/core/src/math/traits.rs @@ -85,3 +85,29 @@ where self.mapv(|x| x.sqrt()) } } + +impl Exp for ArrayBase +where + A: Clone + Exp, + D: Dimension, + S: Data, +{ + type Output = Array; + + fn exp(self) -> Self::Output { + self.mapv(|x| x.exp()) + } +} + +impl<'a, A, S, D> Exp for &'a ArrayBase +where + A: Clone + ComplexFloat, + D: Dimension, + S: Data, +{ + type Output = Array; + + fn exp(self) -> Self::Output { + self.mapv(|x| x.exp()) + } +} diff --git a/core/src/params/impls/impl_rand.rs b/core/src/params/impls/impl_rand.rs index f8681da9..f0aa0642 100644 --- a/core/src/params/impls/impl_rand.rs +++ b/core/src/params/impls/impl_rand.rs @@ -3,7 +3,7 @@ Contrib: FL03 */ use crate::params::Parameter; -use crate::rand::GenerateRandom; +use crate::rand::InitializeExt; use ndarray::{Array, Dimension}; use ndrand::rand_distr::uniform::SampleUniform; use ndrand::rand_distr::{Distribution, StandardNormal}; @@ -17,7 +17,7 @@ where { pub fn init_uniform(mut self, dk: T) -> Self { let dim = self.value.dim(); - self.value = Array::uniform_between(dk, dim); + self.value = Array::uniform(dim, dk); self } } diff --git a/core/src/rand/generate.rs b/core/src/rand/generate.rs deleted file mode 100644 index 7e9b27a3..00000000 --- a/core/src/rand/generate.rs +++ /dev/null @@ -1,99 +0,0 @@ -/* - Appellation: generate - Contrib: FL03 -*/ 
-use core::ops::Neg; -use ndarray::*; -use ndrand::rand::rngs::StdRng; -use ndrand::rand::{Rng, SeedableRng}; -use ndrand::rand_distr::uniform::{SampleUniform, Uniform}; -use ndrand::rand_distr::{Bernoulli, BernoulliError, Distribution, StandardNormal}; -use ndrand::RandomExt; -use num::traits::real::Real; -use num::traits::Float; - -pub trait GenerateRandom: Sized -where - D: Dimension, -{ - fn rand(dim: Sh, distr: IdS) -> Self - where - IdS: Distribution, - Sh: ShapeBuilder; - - fn rand_using(dim: Sh, distr: IdS, rng: &mut R) -> Self - where - IdS: Distribution, - R: Rng, - Sh: ShapeBuilder; - - fn bernoulli(dim: impl IntoDimension, p: Option) -> Result - where - Bernoulli: Distribution, - { - let dist = Bernoulli::new(p.unwrap_or(0.5))?; - Ok(Self::rand(dim.into_dimension(), dist)) - } - - fn stdnorm(dim: impl IntoDimension) -> Self - where - StandardNormal: Distribution, - { - Self::rand(dim, StandardNormal) - } - - fn normal_from_key(key: u64, dim: impl IntoDimension) -> Self - where - StandardNormal: Distribution, - R: Rng, - { - Self::rand_using( - dim.into_dimension(), - StandardNormal, - &mut StdRng::seed_from_u64(key), - ) - } - /// Generate a random array with values between u(-a, a) where a is the reciprocal of the value at the given axis - fn uniform(axis: usize, dim: impl IntoDimension) -> Self - where - T: Real + SampleUniform, - { - let dim = dim.into_dimension(); - let dk = T::from(dim[axis]).unwrap().recip(); - Self::uniform_between(dk, dim) - } - - fn uniform_between(dk: T, dim: impl IntoDimension) -> Self - where - T: Copy + Neg + SampleUniform, - { - Self::rand(dim, Uniform::new(-dk, dk)) - } -} - -/* - ************ Implementations ************ -*/ -impl GenerateRandom for Array -where - A: Float + SampleUniform, - D: Dimension, - StandardNormal: Distribution, -{ - fn rand(dim: Sh, distr: Dtr) -> Self - where - Dtr: Distribution, - Sh: ShapeBuilder, - { - Self::random(dim, distr) - } - - fn rand_using(dim: Sh, distr: Dtr, rng: &mut R) -> Self - 
where - Dtr: Distribution, - R: Rng + ?Sized, - Sh: ShapeBuilder, - { - Self::random_using(dim, distr, rng) - } -} diff --git a/core/src/rand/initialize.rs b/core/src/rand/initialize.rs index e3948967..d2327111 100644 --- a/core/src/rand/initialize.rs +++ b/core/src/rand/initialize.rs @@ -19,18 +19,30 @@ where S: RawData, { /// Generate a random array using the given distribution - fn genrand(shape: Sh, distr: Ds) -> ArrayBase + fn rand(shape: Sh, distr: Ds) -> ArrayBase where S: DataOwned, Ds: Distribution, Sh: ShapeBuilder; /// Generate a random array using the given distribution and random number generator - fn genrand_with(shape: Sh, distr: Ds, rng: &mut R) -> ArrayBase + fn rand_with(shape: Sh, distr: Ds, rng: &mut R) -> ArrayBase where R: Rng + ?Sized, S: DataOwned, Ds: Distribution, Sh: ShapeBuilder; + /// Initialize an array with random values using the given distribution and current shape + fn init_rand(self, distr: Ds) -> ArrayBase + where + S: DataOwned, + Ds: Distribution, + Self: Sized; + /// Initialize an array with random values from the current shape using the given distribution and random number generator + fn init_rand_with(self, distr: Ds, rng: &mut R) -> ArrayBase + where + R: Rng + ?Sized, + S: DataOwned, + Ds: Distribution; } /// This trait extends the [Initialize] trait with methods for generating random arrays from various distributions. 
@@ -46,7 +58,7 @@ where Bernoulli: Distribution, { let dist = Bernoulli::new(p.unwrap_or(0.5))?; - Ok(Self::genrand(shape, dist)) + Ok(Self::rand(shape, dist)) } /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution fn stdnorm(shape: Sh) -> ArrayBase @@ -55,7 +67,7 @@ where Sh: ShapeBuilder, StandardNormal: Distribution, { - Self::genrand(shape, StandardNormal) + Self::rand(shape, StandardNormal) } /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution with a given seed fn stdnorm_from_seed(shape: Sh, seed: u64) -> ArrayBase @@ -64,13 +76,12 @@ where Sh: ShapeBuilder, StandardNormal: Distribution, { - Self::genrand_with( + Self::rand_with( shape, StandardNormal, &mut rngs::StdRng::seed_from_u64(seed), ) } - /// A [uniform](rand_distr::uniform::Uniform) generator with values between u(-dk, dk) fn uniform(shape: Sh, dk: A) -> ArrayBase where @@ -78,7 +89,7 @@ where S: DataOwned, Sh: ShapeBuilder, { - Self::genrand(shape, Uniform::new(dk.clone().neg(), dk)) + Self::rand(shape, Uniform::new(dk.clone().neg(), dk)) } /// Generate a random array with values between u(-a, a) where a is the reciprocal of the value at the given axis fn uniform_along(shape: Sh, axis: usize) -> ArrayBase @@ -98,7 +109,7 @@ where S: DataOwned, Sh: ShapeBuilder, { - Self::genrand(shape, Uniform::new(a, b)) + Self::rand(shape, Uniform::new(a, b)) } } /* @@ -110,7 +121,7 @@ where S: RawData, ArrayBase: RandomExt, { - fn genrand(shape: Sh, distr: Ds) -> ArrayBase + fn rand(shape: Sh, distr: Ds) -> ArrayBase where S: DataOwned, Ds: Distribution, @@ -119,7 +130,7 @@ where Self::random(shape, distr) } - fn genrand_with(shape: Sh, distr: Ds, rng: &mut R) -> ArrayBase + fn rand_with(shape: Sh, distr: Ds, rng: &mut R) -> ArrayBase where R: Rng + ?Sized, S: DataOwned, @@ -128,6 +139,23 @@ where { Self::random_using(shape, distr, rng) } + + fn init_rand(self, distr: Ds) -> ArrayBase + where + S: DataOwned, + Ds: 
Distribution, + { + Self::rand(self.dim(), distr) + } + + fn init_rand_with(self, distr: Ds, rng: &mut R) -> ArrayBase + where + R: Rng + ?Sized, + S: DataOwned, + Ds: Distribution, + { + Self::rand_with(self.dim(), distr, rng) + } } impl InitializeExt for U diff --git a/core/src/rand/mod.rs b/core/src/rand/mod.rs index 2076b360..58bc5810 100644 --- a/core/src/rand/mod.rs +++ b/core/src/rand/mod.rs @@ -6,7 +6,6 @@ pub use self::prelude::*; -pub(crate) mod generate; pub(crate) mod initialize; pub(crate) mod utils; @@ -26,7 +25,6 @@ pub use rand_distr; pub(crate) mod prelude { #[doc(hidden)] - pub use super::generate::GenerateRandom; pub use super::initialize::{Initialize, InitializeExt}; pub use super::utils::*; } diff --git a/core/src/traits/predict.rs b/core/src/traits/predict.rs index fd294757..ab842988 100644 --- a/core/src/traits/predict.rs +++ b/core/src/traits/predict.rs @@ -7,7 +7,7 @@ use crate::error::PredictError; #[doc(hidden)] pub trait Activate: Forward where - F: Fn(&Self::Output) -> Self::Output, + F: for<'a> Fn(&'a Self::Output) -> Self::Output, { fn activate(&self, args: &T, f: F) -> Self::Output { f(&self.forward(args)) @@ -39,17 +39,17 @@ pub trait Predict { /* ********* Implementations ********* */ -impl Forward for Option +impl Forward for S where - S: Forward, - T: Clone, + S: Predict, { - type Output = T; + type Output = Y; - fn forward(&self, args: &T) -> Self::Output { - match self { - Some(s) => s.forward(args), - None => args.clone(), + fn forward(&self, args: &X) -> Self::Output { + if let Ok(y) = self.predict(args) { + y + } else { + panic!("Error in forward propagation") } } } diff --git a/models/linear/src/impls/params/impl_params.rs b/models/linear/src/impls/params/impl_params.rs index f3808a22..17ae9863 100644 --- a/models/linear/src/impls/params/impl_params.rs +++ b/models/linear/src/impls/params/impl_params.rs @@ -82,51 +82,42 @@ where } } -macro_rules! impl_predict { - ($($name:ident),* $(,)?) 
=> { - $(impl_predict!(@impl $name);)* - }; - (@impl $name:ident) => { - impl Predict for $name - where - A: Dot, Output = B>, - B: for<'a> Add<&'a ArrayBase, Output = B>, - D: RemoveAxis, - S: Data, - T: ComplexFloat, - { - type Output = B; +impl Predict for ParamsBase +where + A: Dot, Output = B>, + B: for<'a> Add<&'a ArrayBase, Output = B>, + D: RemoveAxis, + S: Data, + T: ComplexFloat, +{ + type Output = B; - fn predict(&self, input: &A) -> Result { - let wt = self.weights().t().to_owned(); - let mut res = input.dot(&wt); - if let Some(bias) = self.bias.as_ref() { - res = res + bias; - } - Ok(res) - } + fn predict(&self, input: &A) -> Result { + let wt = self.weights().t().to_owned(); + let mut res = input.dot(&wt); + if let Some(bias) = self.bias.as_ref() { + res = res + bias; } + Ok(res) + } +} - impl<'a, A, B, T, S, D, K> Predict for &'a $name - where - A: Dot, Output = B>, - B: Add<&'a ArrayBase, Output = B>, - D: RemoveAxis, - S: Data, - T: ComplexFloat, - { - type Output = B; +impl<'a, A, B, T, S, D, K> Predict for &'a ParamsBase +where + A: Dot, Output = B>, + B: Add<&'a ArrayBase, Output = B>, + D: RemoveAxis, + S: Data, + T: ComplexFloat, +{ + type Output = B; - fn predict(&self, input: &A) -> Result { - let wt = self.weights().t().to_owned(); - let mut res = input.dot(&wt); - if let Some(bias) = self.bias.as_ref() { - res = res + bias; - } - Ok(res) - } + fn predict(&self, input: &A) -> Result { + let wt = self.weights().t().to_owned(); + let mut res = input.dot(&wt); + if let Some(bias) = self.bias.as_ref() { + res = res + bias; } - }; + Ok(res) + } } - -impl_predict!(ParamsBase); diff --git a/models/linear/src/model/linear.rs b/models/linear/src/model/linear.rs index 55984521..f07d49d7 100644 --- a/models/linear/src/model/linear.rs +++ b/models/linear/src/model/linear.rs @@ -51,10 +51,10 @@ where /// Applies an activcation function onto the prediction of the model. 
pub fn activate(&self, args: &X, func: F) -> Result where - F: for<'a> Fn(&'a Y) -> Y, + F: Fn(Y) -> Y, Self: Predict, { - Ok(func(&self.predict(args)?)) + Ok(func(self.predict(args)?)) } pub const fn config(&self) -> &Config { From 2ce5c5510c0699b222723027b0197532993837b1 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sat, 18 May 2024 10:58:45 -0500 Subject: [PATCH 17/23] update Signed-off-by: Joe McCain III --- Cargo.toml | 2 +- core/src/error/kinds/predict.rs | 9 +- core/src/error/kinds/shape.rs | 5 +- core/src/lib.rs | 2 - core/src/macros.rs | 6 +- core/src/ops/pad.rs | 7 +- core/src/ops/pad/action.rs | 3 + core/src/primitives.rs | 28 +- core/src/traits/arr/create.rs | 81 +++--- core/src/traits/arr/misc.rs | 33 ++- core/src/traits/arr/ops.rs | 6 +- core/src/traits/arr/shape.rs | 38 --- core/src/traits/arr/tensor.rs | 265 ++++++++++++++++++ core/src/traits/generator.rs | 11 - core/src/traits/mod.rs | 8 +- core/src/types/mod.rs | 8 +- core/src/types/{direction.rs => propagate.rs} | 10 +- data/Cargo.toml | 4 + data/src/lib.rs | 7 +- {core => data}/src/params/impls/impl_rand.rs | 6 +- {core => data}/src/params/kinds.rs | 12 - {core => data}/src/params/mod.rs | 6 +- {core => data}/src/params/parameter.rs | 0 {core => data}/src/params/store.rs | 7 +- {core => data}/tests/params.rs | 6 +- models/linear/src/impls/model/impl_linear.rs | 2 +- models/linear/src/impls/params/impl_from.rs | 2 +- models/linear/src/model/linear.rs | 6 +- models/linear/src/norm/batch/mod.rs | 14 + models/linear/src/norm/batch/model.rs | 6 + models/linear/src/norm/layer/model.rs | 4 +- models/linear/src/norm/mod.rs | 2 + models/linear/src/params/store.rs | 13 +- models/transformers/src/attention/head.rs | 2 +- models/transformers/src/attention/mod.rs | 27 +- models/transformers/src/codec/decoder.rs | 3 + .../transformers/src/codec/decoder/layer.rs | 13 + models/transformers/src/codec/encoder.rs | 3 + .../transformers/src/codec/encoder/layer.rs | 13 + 39 files changed, 490 insertions(+), 
190 deletions(-) delete mode 100644 core/src/traits/arr/shape.rs create mode 100644 core/src/traits/arr/tensor.rs delete mode 100644 core/src/traits/generator.rs rename core/src/types/{direction.rs => propagate.rs} (87%) rename {core => data}/src/params/impls/impl_rand.rs (74%) rename {core => data}/src/params/kinds.rs (87%) rename {core => data}/src/params/mod.rs (91%) rename {core => data}/src/params/parameter.rs (100%) rename {core => data}/src/params/store.rs (93%) rename {core => data}/tests/params.rs (85%) create mode 100644 models/linear/src/norm/batch/mod.rs create mode 100644 models/linear/src/norm/batch/model.rs create mode 100644 models/transformers/src/codec/decoder/layer.rs create mode 100644 models/transformers/src/codec/encoder/layer.rs diff --git a/Cargo.toml b/Cargo.toml index 90f493ec..32f3103c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ version = "0.1.14" [workspace.dependencies] # acme = { features = ["full"], branch = "v0.3.2", git = "https://github.com/FL03/acme", version = "0.3.2" } # ndtensor = { features = ["full"], branch = "v0.1.1", git = "https://github.com/FL03/ndtensor", version = "0.1" } -scsys = { default-features = false, branch = "v0.2.3", git = "https://github.com/scattered-systems/scsys.git", version = "0.2" } +scsys = { default-features = false, branch = "v0.2.3", features = ["derive"], git = "https://github.com/scattered-systems/scsys.git", version = "0.2" } approx = "0.5" itertools = "0.13" diff --git a/core/src/error/kinds/predict.rs b/core/src/error/kinds/predict.rs index 114657c1..1e3ede3c 100644 --- a/core/src/error/kinds/predict.rs +++ b/core/src/error/kinds/predict.rs @@ -2,6 +2,7 @@ Appellation: error Contrib: FL03 */ +use scsys::VariantConstructors; use smart_default::SmartDefault; use strum::{AsRefStr, Display, EnumCount, EnumIs, EnumIter, EnumString, VariantNames}; @@ -20,6 +21,7 @@ use strum::{AsRefStr, Display, EnumCount, EnumIs, EnumIter, EnumString, VariantN PartialEq, PartialOrd, SmartDefault, + 
VariantConstructors, VariantNames, )] #[cfg_attr( @@ -35,10 +37,3 @@ pub enum PredictError { TypeError, } -impl PredictError { - variant_constructor!( - ArithmeticError.arithmetic_error, - ShapeMismatch.shape_mismatch, - TypeError.type_error - ); -} diff --git a/core/src/error/kinds/shape.rs b/core/src/error/kinds/shape.rs index 9e493334..d7c2c831 100644 --- a/core/src/error/kinds/shape.rs +++ b/core/src/error/kinds/shape.rs @@ -27,10 +27,9 @@ use strum::{AsRefStr, Display, EnumCount, EnumIs, EnumIter, EnumString, VariantN )] #[strum(serialize_all = "snake_case")] pub enum ShapeError { - LayoutError, + IncompatibleLayout, + IncompatibleRank, ShapeMismatch, - RankMismatch, SizeMismatch, - Unknown, } diff --git a/core/src/lib.rs b/core/src/lib.rs index b6c0acb9..13004e89 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -27,7 +27,6 @@ pub mod func; pub mod math; pub mod nn; pub mod ops; -pub mod params; #[cfg(feature = "rand")] pub mod rand; pub mod traits; @@ -42,7 +41,6 @@ pub mod prelude { pub use super::math::prelude::*; pub use super::nn::prelude::*; pub use super::ops::prelude::*; - pub use super::params::prelude::*; pub use super::primitives::*; #[cfg(feature = "rand")] pub use super::rand::prelude::*; diff --git a/core/src/macros.rs b/core/src/macros.rs index 33501dfd..df13c3b0 100644 --- a/core/src/macros.rs +++ b/core/src/macros.rs @@ -56,13 +56,13 @@ macro_rules! variant_constructor { variant_constructor!(@loop $($rest),*); )* }; - ($($variant:ident.$method:ident$(($call:expr))?),* $(,)?) => { + ($($variant:ident::$method:ident$(($call:expr))?),* $(,)?) => { $( - variant_constructor!(@loop $variant.$method$(($call))?); + variant_constructor!(@loop $variant::$method$(($call))?); )* }; - (@loop $variant:ident.$method:ident$(($call:expr))?) => { + (@loop $variant:ident::$method:ident$(($call:expr))?) => { pub fn $method() -> Self { Self::$variant$(($call))? 
} diff --git a/core/src/ops/pad.rs b/core/src/ops/pad.rs index 2b9c4367..e6962770 100644 --- a/core/src/ops/pad.rs +++ b/core/src/ops/pad.rs @@ -67,6 +67,7 @@ impl Padding { } mod utils { + #![cfg(any(feature = "std", feature = "alloc"))] use super::{PadAction, PadMode}; use crate::traits::ArrayLike; use nd::{Array, ArrayBase, AxisDescription, Data, DataOwned, Dimension, Slice}; @@ -77,7 +78,7 @@ mod utils { #[cfg(feature = "std")] use std::borrow::Cow; - fn read_pad(nb_dim: usize, pad: &[[usize; 2]]) -> Cow<[[usize; 2]]> { + fn reader(nb_dim: usize, pad: &[[usize; 2]]) -> Cow<[[usize; 2]]> { if pad.len() == 1 && pad.len() < nb_dim { // The user provided a single padding for all dimensions Cow::from(vec![pad[0]; nb_dim]) @@ -94,7 +95,7 @@ mod utils { D: Dimension, S: DataOwned, { - let pad = read_pad(data.ndim(), pad); + let pad = reader(data.ndim(), pad); let mut new_dim = data.raw_dim(); for (ax, (&ax_len, pad)) in data.shape().iter().zip(pad.iter()).enumerate() { new_dim[ax] = ax_len + pad[0] + pad[1]; @@ -116,7 +117,7 @@ mod utils { D: Dimension, S: Data, { - let pad = read_pad(data.ndim(), pad); + let pad = reader(data.ndim(), pad); // Select portion of padded array that needs to be copied from the original array. 
output diff --git a/core/src/ops/pad/action.rs b/core/src/ops/pad/action.rs index c3a5c521..1abb60d6 100644 --- a/core/src/ops/pad/action.rs +++ b/core/src/ops/pad/action.rs @@ -2,6 +2,7 @@ Appellation: action Contrib: FL03 */ +use scsys::VariantConstructors; use strum::{AsRefStr, Display, EnumCount, EnumIs, EnumIter, EnumString, VariantNames}; #[derive( @@ -20,6 +21,7 @@ use strum::{AsRefStr, Display, EnumCount, EnumIs, EnumIter, EnumString, VariantN Ord, PartialEq, PartialOrd, + VariantConstructors, VariantNames, )] #[cfg_attr( @@ -37,3 +39,4 @@ pub enum PadAction { StopAfterCopy, Wrapping, } + diff --git a/core/src/primitives.rs b/core/src/primitives.rs index 7f881602..c7f8894b 100644 --- a/core/src/primitives.rs +++ b/core/src/primitives.rs @@ -4,22 +4,24 @@ */ pub use consts::*; -pub mod consts { - pub const DEFAULT_MODEL_SIZE: usize = 2048; - pub const EPSILON: f64 = 1e-8; + + +pub mod consts { + /// The default model size for any given model + pub const D_MODEL: usize = 512; + /// The default epsilon value for floating point operations + pub const EPSILON: f64 = 1e-5; } -#[allow(unused_imports)] +#[allow(unused)] pub(crate) mod rust { pub(crate) use core::*; - - #[cfg(no_std)] + #[cfg(all(feature = "alloc", no_std))] pub(crate) use self::no_std::*; #[cfg(feature = "std")] pub(crate) use self::with_std::*; - - #[cfg(no_std)] + #[cfg(all(feature = "alloc", no_std))] mod no_std { pub use alloc::borrow::Cow; pub use alloc::boxed::{self, Box}; @@ -31,12 +33,14 @@ pub(crate) mod rust { pub use std::borrow::Cow; pub use std::boxed::{self, Box}; pub use std::collections::{self, BTreeMap, BTreeSet, BinaryHeap, VecDeque}; - pub(crate) use std::sync::Arc; + pub use std::sync::Arc; pub use std::vec::{self, Vec}; + + } - #[cfg(no_std)] - pub type Map = collections::BTreeMap; + #[cfg(all(feature = "alloc", no_std))] + pub type Map = alloc::collections::BTreeMap; #[cfg(feature = "std")] - pub type Map = collections::HashMap; + pub type Map = std::collections::HashMap; } 
diff --git a/core/src/traits/arr/create.rs b/core/src/traits/arr/create.rs index 918cb6a3..b99a5eaa 100644 --- a/core/src/traits/arr/create.rs +++ b/core/src/traits/arr/create.rs @@ -2,10 +2,10 @@ Appellation: create Contrib: FL03 */ -use nd::{ArrayBase, DataOwned, Dimension, RawData, ShapeBuilder}; +use nd::{ArrayBase, DataOwned, Dimension, Ix2, ShapeBuilder}; use num::traits::Num; -pub trait TensorConstructor +pub trait NdLike where Self: DefaultLike + FillLike @@ -14,63 +14,38 @@ where { } -pub trait ArrayLike +pub trait ArrayLike where D: Dimension, - S: RawData, { - fn array_like(&self, shape: Sh, elem: A) -> ArrayBase - where - Sh: ShapeBuilder; -} - -impl ArrayLike for ArrayBase -where - A: Clone, - D: Dimension, - S: nd::DataOwned, -{ - fn array_like(&self, shape: Sh, elem: A) -> ArrayBase - where - Sh: ShapeBuilder, - { - if self.is_standard_layout() { - ArrayBase::from_elem(shape, elem) - } else { - ArrayBase::from_elem(shape.f(), elem) - } - } -} - -pub trait DefaultLike { type Output; - fn default_like(&self) -> Self::Output; + fn array_like(&self, shape: Sh, elem: A) -> Self::Output + where + Sh: ShapeBuilder; } -pub trait FillLike { - type Output; - - fn fill_like(&self, elem: T) -> Self::Output; -} +macro_rules! ndlike { + ($($name:ident::$(<$($T:ident),*>::)?$method:ident $(($($field:ident:$ft:ty),*))?),* $(,)?) => { + $(ndlike!(@impl $name::$(<$($T),*>::)?$method$(($($field:$ft),*))?);)* + }; + (@impl $name:ident::$(<$($T:ident),*>::)?$method:ident$(($($field:ident: $ft:ty),*))?) => { + pub trait $name$(<$($T),*>)? { + type Output; -pub trait OnesLike { - type Output; + fn $method(&self $(, $($field:$ft),*)?) 
-> Self::Output; + } + }; - fn ones_like(&self) -> Self::Output; } -pub trait ZerosLike { - type Output; - - fn zeros_like(&self) -> Self::Output; -} +ndlike!(DefaultLike::default_like, OnesLike::ones_like, ZerosLike::zeros_like, FillLike::::fill_like(elem: T)); /* ******** implementations ******** */ -impl TensorConstructor> for ArrayBase +impl NdLike> for ArrayBase where A: Clone + Default + Num, D: Dimension, @@ -78,6 +53,26 @@ where { } +impl ArrayLike for ArrayBase +where + A: Clone, + D: Dimension, + S: nd::DataOwned, +{ + type Output = ArrayBase; + + fn array_like(&self, shape: Sh, elem: A) -> Self::Output + where + Sh: ShapeBuilder, + { + if self.is_standard_layout() { + ArrayBase::from_elem(shape, elem) + } else { + ArrayBase::from_elem(shape.f(), elem) + } + } +} + impl FillLike for ArrayBase where A: Clone, diff --git a/core/src/traits/arr/misc.rs b/core/src/traits/arr/misc.rs index 1c5ac006..4cc76e4c 100644 --- a/core/src/traits/arr/misc.rs +++ b/core/src/traits/arr/misc.rs @@ -5,6 +5,16 @@ use nd::Axis; use nd::{ArrayBase, Dimension, RawData}; +pub trait Dimensional { + type Pattern; + + fn dim(&self) -> Self::Pattern; + + fn raw_dim(&self) -> D; + + fn shape(&self) -> &[usize]; +} + pub trait IntoAxis { fn into_axis(self) -> Axis; } @@ -16,6 +26,26 @@ pub trait IsSquare { /* ******** implementations ******** */ +impl Dimensional for ArrayBase +where + D: Dimension, + S: RawData, +{ + type Pattern = D::Pattern; + + fn shape(&self) -> &[usize] { + ArrayBase::shape(self) + } + + fn dim(&self) -> Self::Pattern { + ArrayBase::dim(self) + } + + fn raw_dim(&self) -> D { + ArrayBase::raw_dim(self) + } +} + impl IntoAxis for S where S: AsRef, @@ -31,6 +61,7 @@ where S: RawData, { fn is_square(&self) -> bool { - self.shape().iter().all(|&x| x == self.shape()[0]) + let first = self.shape().first().unwrap(); + self.shape().iter().all(|x| x == first) } } diff --git a/core/src/traits/arr/ops.rs b/core/src/traits/arr/ops.rs index 60d829f7..0072fc4d 100644 --- 
a/core/src/traits/arr/ops.rs +++ b/core/src/traits/arr/ops.rs @@ -1,6 +1,6 @@ /* - Appellation: arr - Contrib: FL03 + Appellation: ops + Contrib: FL03 */ use nd::linalg::Dot; use nd::*; @@ -31,7 +31,7 @@ pub trait Matpow { } /* - ********* Implementations ********* + ********* Implementations ********* */ impl Affine for Array where diff --git a/core/src/traits/arr/shape.rs b/core/src/traits/arr/shape.rs deleted file mode 100644 index b72e993a..00000000 --- a/core/src/traits/arr/shape.rs +++ /dev/null @@ -1,38 +0,0 @@ -/* - Appellation: shape - Contrib: FL03 -*/ -use nd::{ArrayBase, Dimension, RawData}; - -pub trait Dimensional { - type Pattern; - - fn dim(&self) -> Self::Pattern; - - fn raw_dim(&self) -> D; - - fn shape(&self) -> &[usize]; -} - -/* - ********* Implementations ********* -*/ -impl Dimensional for ArrayBase -where - D: Dimension, - S: RawData, -{ - type Pattern = D::Pattern; - - fn shape(&self) -> &[usize] { - ArrayBase::shape(self) - } - - fn dim(&self) -> Self::Pattern { - ArrayBase::dim(self) - } - - fn raw_dim(&self) -> D { - ArrayBase::raw_dim(self) - } -} diff --git a/core/src/traits/arr/tensor.rs b/core/src/traits/arr/tensor.rs new file mode 100644 index 00000000..7b764643 --- /dev/null +++ b/core/src/traits/arr/tensor.rs @@ -0,0 +1,265 @@ +/* + Appellation: generator + Contrib: FL03 +*/ +use nd::prelude::*; +use nd::{Data, DataMut, DataOwned, OwnedRepr, RawData}; +use num::{One, Zero}; + +/// [NdBuilder] describes common creation routines for [ArrayBase](ndarray::ArrayBase) +pub trait NdBuilder +where + D: Dimension, +{ + type Data: RawData; + type Store; + + /// Create a new array with the given shape whose elements are set to the default value of the element type. 
+ fn default(shape: Sh) -> Self::Store + where + A: Default, + Sh: ShapeBuilder, + Self::Data: DataOwned; + + fn fill(shape: Sh, elem: A) -> Self::Store + where + A: Clone, + Sh: ShapeBuilder, + Self::Data: DataOwned; + + fn ones(shape: Sh) -> Self::Store + where + A: Clone + One, + Sh: ShapeBuilder, + Self::Data: DataOwned; + + fn zeros(shape: Sh) -> Self::Store + where + A: Clone + Zero, + Sh: ShapeBuilder, + Self::Data: DataOwned; +} + +pub trait NdBuilderExt: NdBuilder +where + D: Dimension, +{ + fn dim(&self) -> D::Pattern; + + fn default_like(&self) -> Self::Store + where + A: Default, + Sh: ShapeBuilder, + Self::Data: DataOwned + { + Self::default(self.dim()) + } + + fn fill_like(&self, elem: A) -> Self::Store + where + A: Clone, + Sh: ShapeBuilder, + Self::Data: DataOwned + { + Self::fill(self.dim(), elem) + } + + fn ones_like(&self) -> Self::Store + where + A: Clone + One, + Sh: ShapeBuilder, + Self::Data: DataOwned + { + Self::ones(self.dim()) + } + + fn zeros_like(&self) -> Self::Store + where + A: Clone + Zero, + Sh: ShapeBuilder, + Self::Data: DataOwned + { + Self::zeros(self.dim()) + } +} + +pub trait AsOwned +where + D: Dimension, + S: RawData +{ + type Output; + + fn into_owned(self) -> Self::Output + where + S: Data, + S::Elem: Clone; + + fn to_owned(&self) -> Self::Output + where + S: Data, + S::Elem: Clone; +} + +pub trait AsShared +where + D: Dimension, + S: RawData +{ + type Output; + + fn into_shared(self) -> Self::Output + where + S: DataOwned, + S::Elem: Clone; + + fn to_shared(&self) -> Self::Output + where + S: DataOwned, + S::Elem: Clone; +} + +pub trait NdView, D = Ix2>: AsOwned + AsShared +where + D: Dimension, + S: RawData, +{ + + fn view(&self) -> ArrayView<'_, A, D> + where + A: Clone, + S: Data; + + fn view_mut(&mut self) -> ArrayViewMut<'_, A, D> + where + A: Clone, + S: DataMut; +} + +/* + ************* Implementations ************* +*/ +impl NdBuilder for ArrayBase +where + D: Dimension, + S: RawData, +{ + type Data = S; + type 
Store = ArrayBase; + + fn default(shape: Sh) -> Self + where + A: Default, + Sh: ShapeBuilder, + Self::Data: DataOwned, + { + ArrayBase::default(shape) + } + + fn fill(shape: Sh, elem: A) -> Self + where + A: Clone, + S: DataOwned, + Sh: ShapeBuilder, + { + ArrayBase::from_elem(shape, elem) + } + + fn ones(shape: Sh) -> Self + where + A: Clone + One, + Sh: ShapeBuilder, + Self::Data: DataOwned, + { + ArrayBase::ones(shape) + } + + fn zeros(shape: Sh) -> Self + where + A: Clone + Zero, + Sh: ShapeBuilder, + Self::Data: DataOwned, + { + ArrayBase::zeros(shape) + } +} + +impl NdBuilderExt for ArrayBase +where + D: Dimension, + S: RawData, +{ + fn dim(&self) -> D::Pattern { + ArrayBase::dim(self) + } +} + +impl AsOwned for ArrayBase +where + D: Dimension, + S: RawData, +{ + type Output = Array; + + fn into_owned(self) -> Self::Output + where + A: Clone, + S: Data, + { + self.into_owned() + } + + fn to_owned(&self) -> Self::Output + where + A: Clone, + S: Data, + { + self.to_owned() + } +} + +impl AsShared for ArrayBase +where + D: Dimension, + S: RawData, +{ + type Output = ArcArray; + + fn into_shared(self) -> Self::Output + where + A: Clone, + S: DataOwned, + { + self.into_shared() + } + + fn to_shared(&self) -> Self::Output + where + A: Clone, + S: DataOwned, + { + self.to_shared() + } +} + +impl NdView for ArrayBase +where + D: Dimension, + S: RawData, +{ + fn view(&self) -> ArrayView<'_, A, D> + where + A: Clone, + S: Data, + { + self.view() + } + + fn view_mut(&mut self) -> ArrayViewMut<'_, A, D> + where + A: Clone, + S: DataMut, + { + self.view_mut() + } +} \ No newline at end of file diff --git a/core/src/traits/generator.rs b/core/src/traits/generator.rs deleted file mode 100644 index 54d058fd..00000000 --- a/core/src/traits/generator.rs +++ /dev/null @@ -1,11 +0,0 @@ -/* - Appellation: generator - Contrib: FL03 -*/ - -/// This trait describes actors that can generate data -pub trait Generative { - type Output; - - fn generate(&self, args: T) -> Self::Output; 
-} diff --git a/core/src/traits/mod.rs b/core/src/traits/mod.rs index 350e5cfa..db1c9a34 100644 --- a/core/src/traits/mod.rs +++ b/core/src/traits/mod.rs @@ -4,7 +4,6 @@ */ pub use self::prelude::*; -pub mod generator; pub mod num; pub mod ops; pub mod predict; @@ -12,17 +11,17 @@ pub mod train; pub mod arr { pub use self::prelude::*; - + pub(crate) mod create; pub(crate) mod misc; pub(crate) mod ops; - pub(crate) mod shape; + pub(crate) mod tensor; pub(crate) mod prelude { pub use super::create::*; pub use super::misc::*; pub use super::ops::*; - pub use super::shape::*; + pub use super::tensor::*; } } @@ -45,7 +44,6 @@ pub mod misc { pub(crate) mod prelude { pub use super::arr::prelude::*; - pub use super::generator::*; pub use super::misc::prelude::*; pub use super::num::*; pub use super::ops::*; diff --git a/core/src/types/mod.rs b/core/src/types/mod.rs index 655b3673..926a1387 100644 --- a/core/src/types/mod.rs +++ b/core/src/types/mod.rs @@ -6,7 +6,7 @@ pub use self::prelude::*; #[cfg(feature = "std")] pub use self::std_types::*; -pub mod direction; +pub mod propagate; /// A type alias for a [Result](core::result::Result) with the crate's [Error](crate::error::Error) type. /// Defaults to `Result<(), Error>` @@ -14,14 +14,14 @@ pub type Result = core::result::Result; #[cfg(feature = "std")] mod std_types { - /// + /// A type alias for a boxed [Error](std::error::Error) type that is `Send`, `Sync`, and `'static`. pub type BoxError = Box; - /// + /// A type alias for a boxed [Result](core::result::Result) which returns some object, `T`, and uses a [BoxError] as the error type. 
pub type BoxResult = core::result::Result; } pub(crate) mod prelude { - pub use super::direction::Direction; + pub use super::propagate::Propagate; #[cfg(feature = "std")] pub use super::std_types::*; pub use super::Result; diff --git a/core/src/types/direction.rs b/core/src/types/propagate.rs similarity index 87% rename from core/src/types/direction.rs rename to core/src/types/propagate.rs index 2a86ada7..d765ccc5 100644 --- a/core/src/types/direction.rs +++ b/core/src/types/propagate.rs @@ -29,13 +29,13 @@ use strum::{AsRefStr, Display, EnumCount, EnumIs, EnumIter, EnumString, VariantN serde(rename_all = "lowercase") )] #[strum(serialize_all = "lowercase")] -pub enum Direction { +pub enum Propagate { Backward = 0, #[default] Forward = 1, } -impl Direction { +impl Propagate { /// A functional alias for [Direction::Backward]. pub fn backward() -> Self { Self::Backward @@ -46,13 +46,13 @@ impl Direction { } } -impl From for usize { - fn from(direction: Direction) -> Self { +impl From for usize { + fn from(direction: Propagate) -> Self { direction as usize } } -impl From for Direction { +impl From for Propagate { fn from(index: usize) -> Self { match index % Self::COUNT { 0 => Self::Backward, diff --git a/data/Cargo.toml b/data/Cargo.toml index 0ca51862..e6a1c57f 100644 --- a/data/Cargo.toml +++ b/data/Cargo.toml @@ -80,6 +80,10 @@ crate-type = ["lib"] doctest = false test = true +[[test]] +name = "params" +required-features = ["std"] + [build-dependencies] [dependencies] diff --git a/data/src/lib.rs b/data/src/lib.rs index 31348676..0186ca18 100644 --- a/data/src/lib.rs +++ b/data/src/lib.rs @@ -11,17 +11,20 @@ extern crate alloc; extern crate concision_core as concision; +extern crate ndarray as nd; pub use self::dataset::Dataset; pub use self::traits::prelude::*; pub mod dataset; +pub mod params; #[doc(hidden)] pub mod preproc; pub mod tensor; pub mod traits; pub mod prelude { - pub use crate::dataset::*; - pub use crate::traits::prelude::*; + pub use 
super::dataset::*; + pub use super::params::prelude::*; + pub use super::traits::prelude::*; } diff --git a/core/src/params/impls/impl_rand.rs b/data/src/params/impls/impl_rand.rs similarity index 74% rename from core/src/params/impls/impl_rand.rs rename to data/src/params/impls/impl_rand.rs index f0aa0642..3461ffcd 100644 --- a/core/src/params/impls/impl_rand.rs +++ b/data/src/params/impls/impl_rand.rs @@ -3,10 +3,10 @@ Contrib: FL03 */ use crate::params::Parameter; -use crate::rand::InitializeExt; +use concision::InitializeExt; use ndarray::{Array, Dimension}; -use ndrand::rand_distr::uniform::SampleUniform; -use ndrand::rand_distr::{Distribution, StandardNormal}; +use concision::rand::rand_distr::uniform::SampleUniform; +use concision::rand::rand_distr::{Distribution, StandardNormal}; use num::Float; impl Parameter diff --git a/core/src/params/kinds.rs b/data/src/params/kinds.rs similarity index 87% rename from core/src/params/kinds.rs rename to data/src/params/kinds.rs index 70a71885..c676331b 100644 --- a/core/src/params/kinds.rs +++ b/data/src/params/kinds.rs @@ -4,18 +4,6 @@ */ use strum::{AsRefStr, EnumCount, EnumIs, EnumIter, EnumString, VariantNames}; -pub trait ParamType: ToString { - fn kind(&self) -> String; -} - -impl ParamType for T -where - T: ToString, -{ - fn kind(&self) -> String { - self.to_string() - } -} #[derive( AsRefStr, diff --git a/core/src/params/mod.rs b/data/src/params/mod.rs similarity index 91% rename from core/src/params/mod.rs rename to data/src/params/mod.rs index 5ec083bf..8b35b030 100644 --- a/core/src/params/mod.rs +++ b/data/src/params/mod.rs @@ -37,9 +37,9 @@ pub(crate) mod prelude { #[cfg(test)] mod tests { use super::*; - use crate::linarr; - use ndarray::linalg::Dot; - use ndarray::prelude::{Ix1, Ix2}; + use concision::linarr; + use nd::linalg::Dot; + use nd::prelude::*; #[test] fn test_parameter() { diff --git a/core/src/params/parameter.rs b/data/src/params/parameter.rs similarity index 100% rename from 
core/src/params/parameter.rs rename to data/src/params/parameter.rs diff --git a/core/src/params/store.rs b/data/src/params/store.rs similarity index 93% rename from core/src/params/store.rs rename to data/src/params/store.rs index 6428c985..cf451e37 100644 --- a/core/src/params/store.rs +++ b/data/src/params/store.rs @@ -2,11 +2,16 @@ Appellation: store Contrib: FL03 */ +#![cfg(any(feature = "alloc", feature = "std"))] use super::{ParamKind, Parameter}; -use crate::prelude::Map; use ndarray::prelude::{Dimension, Ix2}; use num::Float; +#[cfg(all(feature = "alloc", no_std))] +use alloc::collections::BTreeMap as Map; +#[cfg(feature = "std")] +use std::collections::HashMap as Map; + #[derive(Clone, Debug, Default, Eq, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct ParamStore diff --git a/core/tests/params.rs b/data/tests/params.rs similarity index 85% rename from core/tests/params.rs rename to data/tests/params.rs index e121f607..691f41c7 100644 --- a/core/tests/params.rs +++ b/data/tests/params.rs @@ -2,9 +2,11 @@ Appellation: params Contrib: FL03 */ -extern crate concision_core as cnc; +extern crate concision_core as concision; +extern crate concision_data as data; -use cnc::prelude::{linarr, ParamKind, Parameter}; +use concision::linarr; +use data::params::{Parameter, ParamKind}; use ndarray::linalg::Dot; use ndarray::prelude::*; diff --git a/models/linear/src/impls/model/impl_linear.rs b/models/linear/src/impls/model/impl_linear.rs index e189c639..03c97f82 100644 --- a/models/linear/src/impls/model/impl_linear.rs +++ b/models/linear/src/impls/model/impl_linear.rs @@ -15,7 +15,7 @@ where A: Clone + Default, { let config = Config::std(inputs, outputs); - let params = LinearParams::default(config.dim()); + let params = LinearParams::new(config.dim()); Self { config, params } } } diff --git a/models/linear/src/impls/params/impl_from.rs b/models/linear/src/impls/params/impl_from.rs index ef22d72a..108d7bd3 100644 --- 
a/models/linear/src/impls/params/impl_from.rs +++ b/models/linear/src/impls/params/impl_from.rs @@ -61,7 +61,7 @@ where let mut iter = nodes.iter(); let node = iter.next().unwrap(); let shape = Features::new(node.0.len(), nodes.len()); - let mut params = ParamsBase::default(shape); + let mut params = ParamsBase::new(shape); params.set_node(0, node.clone()); for (i, node) in iter.into_iter().enumerate() { params.set_node(i + 1, node.clone()); diff --git a/models/linear/src/model/linear.rs b/models/linear/src/model/linear.rs index f07d49d7..0e6a5a79 100644 --- a/models/linear/src/model/linear.rs +++ b/models/linear/src/model/linear.rs @@ -25,7 +25,7 @@ impl Linear where D: RemoveAxis, { - impl_model_builder!(default where A: Default); + impl_model_builder!(new where A: Default); impl_model_builder!(ones where A: Clone + num::One); impl_model_builder!(zeros where A: Clone + num::Zero); @@ -34,7 +34,7 @@ where A: Clone + Default, K: ParamMode, { - let params = LinearParams::default(config.dim()); + let params = LinearParams::new(config.dim()); Self { config, params } } @@ -44,7 +44,7 @@ where K: ParamMode, { let config = Config::::new().with_layout(layout); - let params = LinearParams::default(config.dim()); + let params = LinearParams::new(config.dim()); Self { config, params } } diff --git a/models/linear/src/norm/batch/mod.rs b/models/linear/src/norm/batch/mod.rs new file mode 100644 index 00000000..35c46ebd --- /dev/null +++ b/models/linear/src/norm/batch/mod.rs @@ -0,0 +1,14 @@ +/* + Appellation: batch + Contrib: FL03 +*/ +//! # Batch Normalization +//! +//! 
+pub use self::model::*; + +mod model; + +pub(crate) mod prelude { + pub use super::BatchNorm; +} diff --git a/models/linear/src/norm/batch/model.rs b/models/linear/src/norm/batch/model.rs new file mode 100644 index 00000000..ab33e536 --- /dev/null +++ b/models/linear/src/norm/batch/model.rs @@ -0,0 +1,6 @@ +/* + Appellation: model + Contrib: FL03 +*/ + +pub struct BatchNorm; diff --git a/models/linear/src/norm/layer/model.rs b/models/linear/src/norm/layer/model.rs index 979802b6..8895f083 100644 --- a/models/linear/src/norm/layer/model.rs +++ b/models/linear/src/norm/layer/model.rs @@ -34,7 +34,7 @@ where where A: Default, { - let params = LinearParams::::default(config.dim()); + let params = LinearParams::::new(config.dim()); Self { config, params } } @@ -45,7 +45,7 @@ where { let dim = shape.into_shape().raw_dim().clone(); let config = Config::new().dim(dim.clone()).build(); - let params = LinearParams::::default(dim); + let params = LinearParams::::new(dim); Self { config, params } } diff --git a/models/linear/src/norm/mod.rs b/models/linear/src/norm/mod.rs index d8d607dd..5fa9972d 100644 --- a/models/linear/src/norm/mod.rs +++ b/models/linear/src/norm/mod.rs @@ -7,8 +7,10 @@ //! 
pub use self::layer::LayerNorm; +pub mod batch; pub mod layer; pub(crate) mod prelude { + pub use super::batch::prelude::*; pub use super::layer::prelude::*; } diff --git a/models/linear/src/params/store.rs b/models/linear/src/params/store.rs index 2c9aea47..90e2809c 100644 --- a/models/linear/src/params/store.rs +++ b/models/linear/src/params/store.rs @@ -3,6 +3,7 @@ Contrib: FL03 */ use crate::{build_bias, Biased, Features, Node, ParamMode, Unbiased}; +use concision::dimensional; use core::marker::PhantomData; use nd::*; use num::{One, Zero}; @@ -49,10 +50,6 @@ where } } - impl_params_builder!(default where A: Default, S: DataOwned); - impl_params_builder!(ones where A: Clone + One, S: DataOwned); - impl_params_builder!(zeros where A: Clone + Zero, S: DataOwned); - pub fn into_biased(self) -> ParamsBase where A: Default, @@ -113,7 +110,13 @@ where crate::is_biased::() } - concision::dimensional!(weights()); + impl_params_builder!(new.default where A: Default, S: DataOwned); + + impl_params_builder!(ones where A: Clone + One, S: DataOwned); + + impl_params_builder!(zeros where A: Clone + Zero, S: DataOwned); + + dimensional!(weights()); ndview!(into_owned::(self) where A: Clone, S: Data); diff --git a/models/transformers/src/attention/head.rs b/models/transformers/src/attention/head.rs index faab2c8e..8eca2693 100644 --- a/models/transformers/src/attention/head.rs +++ b/models/transformers/src/attention/head.rs @@ -83,6 +83,6 @@ where Array: Dot, Output = Array>, { let (q, k, v) = self.qkv(); - crate::attention::scaled_dot_product(q, k, v) + crate::attention::scaled_dot_product_attention(q, k, v) } } diff --git a/models/transformers/src/attention/mod.rs b/models/transformers/src/attention/mod.rs index 0fb65d4f..d420dc3f 100644 --- a/models/transformers/src/attention/mod.rs +++ b/models/transformers/src/attention/mod.rs @@ -17,29 +17,30 @@ pub(crate) mod prelude { pub(crate) mod utils { use concision::func::activate::Softmax; use nd::linalg::Dot; - use nd::{Array, 
Axis, Dimension, ScalarOperand}; + use nd::{Data, ScalarOperand}; + use nd::prelude::{Array, ArrayBase, Axis, Dimension}; use num::complex::ComplexFloat; - pub(crate) fn scale_dk(dk: A) -> A + pub(crate) fn scale(dk: usize) -> A where - A: ComplexFloat + ScalarOperand, + A: ComplexFloat, { - dk.sqrt().recip() + A::from(dk).unwrap().sqrt().recip() } - pub fn scaled_dot_product( - q: &Array, - k: &Array, - v: &Array, + pub fn scaled_dot_product_attention( + q: &ArrayBase, + k: &ArrayBase, + v: &ArrayBase, ) -> Array where A: ComplexFloat + ScalarOperand, + S: Data, D: Dimension, - Array: Dot, Output = Array>, + ArrayBase: Dot, Output = Array>, + Array: Dot, Output = Array> { - let qk = q.dot(&k.t().to_owned()); - let scale = scale_dk(A::from(k.len_of(Axis(1))).unwrap()); - let scaled = qk * scale.recip(); - scaled.softmax().dot(&v) + let dk = scale::(k.len_of(Axis(1))); + (q.dot(&k.t().to_owned()) * dk).softmax().dot(&v) } } diff --git a/models/transformers/src/codec/decoder.rs b/models/transformers/src/codec/decoder.rs index 019e5e94..b8c8dd6a 100644 --- a/models/transformers/src/codec/decoder.rs +++ b/models/transformers/src/codec/decoder.rs @@ -2,6 +2,9 @@ Appellation: decoder Contrib: FL03 */ +pub use self::layer::DecoderLayer; + +pub mod layer; #[derive(Default)] pub struct Decoder {} diff --git a/models/transformers/src/codec/decoder/layer.rs b/models/transformers/src/codec/decoder/layer.rs new file mode 100644 index 00000000..dd309d44 --- /dev/null +++ b/models/transformers/src/codec/decoder/layer.rs @@ -0,0 +1,13 @@ +/* + Appellation: layer + Contrib: FL03 +*/ + +#[derive(Default)] +pub struct DecoderLayer {} + +impl DecoderLayer { + pub fn new() -> Self { + Self {} + } +} \ No newline at end of file diff --git a/models/transformers/src/codec/encoder.rs b/models/transformers/src/codec/encoder.rs index ba63c02b..a9e6c4ad 100644 --- a/models/transformers/src/codec/encoder.rs +++ b/models/transformers/src/codec/encoder.rs @@ -2,6 +2,9 @@ Appellation: encoder 
Contrib: FL03 */ +pub use self::layer::EncoderLayer; + +pub mod layer; #[derive(Default)] pub struct Encoder {} diff --git a/models/transformers/src/codec/encoder/layer.rs b/models/transformers/src/codec/encoder/layer.rs new file mode 100644 index 00000000..2565a395 --- /dev/null +++ b/models/transformers/src/codec/encoder/layer.rs @@ -0,0 +1,13 @@ +/* + Appellation: layer + Contrib: FL03 +*/ + +#[derive(Default)] +pub struct EncoderLayer {} + +impl EncoderLayer { + pub fn new() -> Self { + Self {} + } +} \ No newline at end of file From 279c5215de8bd98406295adb9df1ee0c1ace454b Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sat, 18 May 2024 11:37:25 -0500 Subject: [PATCH 18/23] update Signed-off-by: Joe McCain III --- core/src/error/kinds/predict.rs | 1 - core/src/func/loss/entropy.rs | 4 ++ core/src/func/{loss.rs => loss/mod.rs} | 4 ++ core/src/func/loss/reg.rs | 13 ++++ core/src/func/loss/reg/avg.rs | 64 +++++++++++++++++++ core/src/math/traits.rs | 39 +++++++++++ core/src/ops/pad/action.rs | 1 - core/src/primitives.rs | 7 +- core/src/traits/arr/ops.rs | 2 +- core/src/traits/arr/tensor.rs | 17 +++-- core/src/traits/mod.rs | 2 +- data/src/params/impls/impl_rand.rs | 4 +- data/src/params/kinds.rs | 1 - data/tests/params.rs | 2 +- .../linear/src/model/{linear.rs => layer.rs} | 0 models/linear/src/model/mod.rs | 6 +- models/transformers/src/attention/mod.rs | 4 +- .../transformers/src/codec/decoder/layer.rs | 2 +- .../transformers/src/codec/encoder/layer.rs | 2 +- 19 files changed, 145 insertions(+), 30 deletions(-) create mode 100644 core/src/func/loss/entropy.rs rename core/src/func/{loss.rs => loss/mod.rs} (76%) create mode 100644 core/src/func/loss/reg.rs create mode 100644 core/src/func/loss/reg/avg.rs rename models/linear/src/model/{linear.rs => layer.rs} (100%) diff --git a/core/src/error/kinds/predict.rs b/core/src/error/kinds/predict.rs index 1e3ede3c..de39d1d8 100644 --- a/core/src/error/kinds/predict.rs +++ b/core/src/error/kinds/predict.rs @@ -36,4 
+36,3 @@ pub enum PredictError { ShapeMismatch, TypeError, } - diff --git a/core/src/func/loss/entropy.rs b/core/src/func/loss/entropy.rs new file mode 100644 index 00000000..1982a20f --- /dev/null +++ b/core/src/func/loss/entropy.rs @@ -0,0 +1,4 @@ +/* + Appellation: entropy + Contrib: FL03 +*/ diff --git a/core/src/func/loss.rs b/core/src/func/loss/mod.rs similarity index 76% rename from core/src/func/loss.rs rename to core/src/func/loss/mod.rs index aed784ab..b88f91f5 100644 --- a/core/src/func/loss.rs +++ b/core/src/func/loss/mod.rs @@ -3,7 +3,11 @@ Contrib: FL03 */ +pub mod entropy; +pub mod reg; + pub(crate) mod prelude { + pub use super::reg::prelude::*; pub use super::Loss; } diff --git a/core/src/func/loss/reg.rs b/core/src/func/loss/reg.rs new file mode 100644 index 00000000..679d8a6d --- /dev/null +++ b/core/src/func/loss/reg.rs @@ -0,0 +1,13 @@ +/* + Appellation: reg + Contrib: FL03 +*/ +//! # Regressive Loss Functions +//! +//! + +pub mod avg; + +pub(crate) mod prelude { + pub use super::avg::*; +} diff --git a/core/src/func/loss/reg/avg.rs b/core/src/func/loss/reg/avg.rs new file mode 100644 index 00000000..163ef665 --- /dev/null +++ b/core/src/func/loss/reg/avg.rs @@ -0,0 +1,64 @@ +/* + Appellation: avg + Contrib: FL03 +*/ +use crate::math::{Abs, Squared}; +use nd::prelude::*; +use nd::{Data, ScalarOperand}; +use num::traits::{FromPrimitive, Num, Pow, Signed}; + +pub fn mae(pred: &ArrayBase, target: &ArrayBase) -> Option +where + A: FromPrimitive + Num + ScalarOperand + Signed, + D: Dimension, + S: Data, +{ + (pred - target).abs().mean() +} + +pub fn mse(pred: &ArrayBase, target: &ArrayBase) -> Option +where + A: FromPrimitive + Num + Pow + ScalarOperand, + D: Dimension, + S: Data, +{ + (pred - target).sqrd().mean() +} + +pub trait MeanAbsoluteError { + type Output; + + fn mae(&self, target: &Rhs) -> Self::Output; +} + +pub trait MeanSquaredError { + type Output; + + fn mse(&self, target: &Rhs) -> Self::Output; +} + +impl MeanAbsoluteError> for 
ArrayBase +where + A: FromPrimitive + Num + ScalarOperand + Signed, + D: Dimension, + S: Data, +{ + type Output = Option; + + fn mae(&self, target: &ArrayBase) -> Self::Output { + (self - target).abs().mean() + } +} + +impl MeanSquaredError> for ArrayBase +where + A: FromPrimitive + Num + Pow + ScalarOperand, + D: Dimension, + S: Data, +{ + type Output = Option; + + fn mse(&self, target: &ArrayBase) -> Self::Output { + (self - target).sqrd().mean() + } +} diff --git a/core/src/math/traits.rs b/core/src/math/traits.rs index 6d1de631..d71d433d 100644 --- a/core/src/math/traits.rs +++ b/core/src/math/traits.rs @@ -4,6 +4,7 @@ */ use nd::{Array, ArrayBase, Data, Dimension}; use num::complex::{Complex, ComplexFloat}; +use num::traits::Signed; unary!( Abs::abs(self), @@ -12,6 +13,7 @@ unary!( Exp::exp(self), Sine::sin(self), Sinh::sinh(self), + Squared::sqrd(self), SquareRoot::sqrt(self) ); @@ -62,6 +64,43 @@ unary_impls!( SquareRoot::sqrt<[f32, f64]> ); +impl Abs for ArrayBase +where + A: Clone + Signed, + D: Dimension, + S: Data, +{ + type Output = Array; + + fn abs(self) -> Self::Output { + self.mapv(|x| x.abs()) + } +} + +impl<'a, A, S, D> Abs for &'a ArrayBase +where + A: Clone + Signed, + D: Dimension, + S: Data, +{ + type Output = Array; + + fn abs(self) -> Self::Output { + self.mapv(|x| x.abs()) + } +} + +impl Squared for A +where + A: Clone + core::ops::Mul, +{ + type Output = A; + + fn sqrd(self) -> Self::Output { + self.clone() * self + } +} + impl SquareRoot for Complex where Complex: ComplexFloat, diff --git a/core/src/ops/pad/action.rs b/core/src/ops/pad/action.rs index 1abb60d6..15325ddc 100644 --- a/core/src/ops/pad/action.rs +++ b/core/src/ops/pad/action.rs @@ -39,4 +39,3 @@ pub enum PadAction { StopAfterCopy, Wrapping, } - diff --git a/core/src/primitives.rs b/core/src/primitives.rs index c7f8894b..c2755083 100644 --- a/core/src/primitives.rs +++ b/core/src/primitives.rs @@ -4,9 +4,6 @@ */ pub use consts::*; - - - pub mod consts { /// The default model 
size for any given model pub const D_MODEL: usize = 512; @@ -16,11 +13,11 @@ pub mod consts { #[allow(unused)] pub(crate) mod rust { - pub(crate) use core::*; #[cfg(all(feature = "alloc", no_std))] pub(crate) use self::no_std::*; #[cfg(feature = "std")] pub(crate) use self::with_std::*; + pub(crate) use core::*; #[cfg(all(feature = "alloc", no_std))] mod no_std { pub use alloc::borrow::Cow; @@ -35,8 +32,6 @@ pub(crate) mod rust { pub use std::collections::{self, BTreeMap, BTreeSet, BinaryHeap, VecDeque}; pub use std::sync::Arc; pub use std::vec::{self, Vec}; - - } #[cfg(all(feature = "alloc", no_std))] diff --git a/core/src/traits/arr/ops.rs b/core/src/traits/arr/ops.rs index 0072fc4d..5fc97c9e 100644 --- a/core/src/traits/arr/ops.rs +++ b/core/src/traits/arr/ops.rs @@ -31,7 +31,7 @@ pub trait Matpow { } /* - ********* Implementations ********* + ********* Implementations ********* */ impl Affine for Array where diff --git a/core/src/traits/arr/tensor.rs b/core/src/traits/arr/tensor.rs index 7b764643..22a00f99 100644 --- a/core/src/traits/arr/tensor.rs +++ b/core/src/traits/arr/tensor.rs @@ -50,7 +50,7 @@ where where A: Default, Sh: ShapeBuilder, - Self::Data: DataOwned + Self::Data: DataOwned, { Self::default(self.dim()) } @@ -59,7 +59,7 @@ where where A: Clone, Sh: ShapeBuilder, - Self::Data: DataOwned + Self::Data: DataOwned, { Self::fill(self.dim(), elem) } @@ -68,7 +68,7 @@ where where A: Clone + One, Sh: ShapeBuilder, - Self::Data: DataOwned + Self::Data: DataOwned, { Self::ones(self.dim()) } @@ -77,7 +77,7 @@ where where A: Clone + Zero, Sh: ShapeBuilder, - Self::Data: DataOwned + Self::Data: DataOwned, { Self::zeros(self.dim()) } @@ -86,7 +86,7 @@ where pub trait AsOwned where D: Dimension, - S: RawData + S: RawData, { type Output; @@ -104,7 +104,7 @@ where pub trait AsShared where D: Dimension, - S: RawData + S: RawData, { type Output; @@ -124,7 +124,6 @@ where D: Dimension, S: RawData, { - fn view(&self) -> ArrayView<'_, A, D> where A: Clone, @@ -137,7 
+136,7 @@ where } /* - ************* Implementations ************* + ************* Implementations ************* */ impl NdBuilder for ArrayBase where @@ -262,4 +261,4 @@ where { self.view_mut() } -} \ No newline at end of file +} diff --git a/core/src/traits/mod.rs b/core/src/traits/mod.rs index db1c9a34..b6aa6b21 100644 --- a/core/src/traits/mod.rs +++ b/core/src/traits/mod.rs @@ -11,7 +11,7 @@ pub mod train; pub mod arr { pub use self::prelude::*; - + pub(crate) mod create; pub(crate) mod misc; pub(crate) mod ops; diff --git a/data/src/params/impls/impl_rand.rs b/data/src/params/impls/impl_rand.rs index 3461ffcd..4293d8a1 100644 --- a/data/src/params/impls/impl_rand.rs +++ b/data/src/params/impls/impl_rand.rs @@ -3,10 +3,10 @@ Contrib: FL03 */ use crate::params::Parameter; -use concision::InitializeExt; -use ndarray::{Array, Dimension}; use concision::rand::rand_distr::uniform::SampleUniform; use concision::rand::rand_distr::{Distribution, StandardNormal}; +use concision::InitializeExt; +use ndarray::{Array, Dimension}; use num::Float; impl Parameter diff --git a/data/src/params/kinds.rs b/data/src/params/kinds.rs index c676331b..b9b92225 100644 --- a/data/src/params/kinds.rs +++ b/data/src/params/kinds.rs @@ -4,7 +4,6 @@ */ use strum::{AsRefStr, EnumCount, EnumIs, EnumIter, EnumString, VariantNames}; - #[derive( AsRefStr, Clone, diff --git a/data/tests/params.rs b/data/tests/params.rs index 691f41c7..134ff343 100644 --- a/data/tests/params.rs +++ b/data/tests/params.rs @@ -6,7 +6,7 @@ extern crate concision_core as concision; extern crate concision_data as data; use concision::linarr; -use data::params::{Parameter, ParamKind}; +use data::params::{ParamKind, Parameter}; use ndarray::linalg::Dot; use ndarray::prelude::*; diff --git a/models/linear/src/model/linear.rs b/models/linear/src/model/layer.rs similarity index 100% rename from models/linear/src/model/linear.rs rename to models/linear/src/model/layer.rs diff --git a/models/linear/src/model/mod.rs 
b/models/linear/src/model/mod.rs index f511cb77..66fdce93 100644 --- a/models/linear/src/model/mod.rs +++ b/models/linear/src/model/mod.rs @@ -3,9 +3,9 @@ Contrib: FL03 */ pub use self::layout::prelude::*; -pub use self::{config::Config, linear::Linear}; +pub use self::{config::Config, layer::Linear}; -mod linear; +mod layer; pub mod config; @@ -22,5 +22,5 @@ pub mod layout { } pub(crate) mod prelude { - pub use super::linear::Linear; + pub use super::layer::Linear; } diff --git a/models/transformers/src/attention/mod.rs b/models/transformers/src/attention/mod.rs index d420dc3f..e317394d 100644 --- a/models/transformers/src/attention/mod.rs +++ b/models/transformers/src/attention/mod.rs @@ -17,8 +17,8 @@ pub(crate) mod prelude { pub(crate) mod utils { use concision::func::activate::Softmax; use nd::linalg::Dot; - use nd::{Data, ScalarOperand}; use nd::prelude::{Array, ArrayBase, Axis, Dimension}; + use nd::{Data, ScalarOperand}; use num::complex::ComplexFloat; pub(crate) fn scale(dk: usize) -> A @@ -38,7 +38,7 @@ pub(crate) mod utils { S: Data, D: Dimension, ArrayBase: Dot, Output = Array>, - Array: Dot, Output = Array> + Array: Dot, Output = Array>, { let dk = scale::(k.len_of(Axis(1))); (q.dot(&k.t().to_owned()) * dk).softmax().dot(&v) diff --git a/models/transformers/src/codec/decoder/layer.rs b/models/transformers/src/codec/decoder/layer.rs index dd309d44..90514acc 100644 --- a/models/transformers/src/codec/decoder/layer.rs +++ b/models/transformers/src/codec/decoder/layer.rs @@ -10,4 +10,4 @@ impl DecoderLayer { pub fn new() -> Self { Self {} } -} \ No newline at end of file +} diff --git a/models/transformers/src/codec/encoder/layer.rs b/models/transformers/src/codec/encoder/layer.rs index 2565a395..10821bd3 100644 --- a/models/transformers/src/codec/encoder/layer.rs +++ b/models/transformers/src/codec/encoder/layer.rs @@ -10,4 +10,4 @@ impl EncoderLayer { pub fn new() -> Self { Self {} } -} \ No newline at end of file +} From 
1087c67096fc5960101ea429236d14691b30001a Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sat, 18 May 2024 14:05:45 -0500 Subject: [PATCH 19/23] update Signed-off-by: Joe McCain III --- core/src/error/kinds/external.rs | 2 +- .../src/func/{activate.rs => activate/mod.rs} | 4 +- core/src/lib.rs | 1 + core/src/macros.rs | 208 +----------------- core/src/macros/builder.rs | 47 ++++ core/src/macros/enums.rs | 44 ++++ core/src/macros/getters.rs | 47 ++++ core/src/macros/ops.rs | 36 +++ core/src/macros/toggle.rs | 19 ++ core/src/ops/fft/{cmp.rs => cmp/direction.rs} | 6 +- core/src/ops/fft/cmp/mode.rs | 44 ++++ core/src/ops/fft/{ => cmp}/plan.rs | 67 +++--- core/src/ops/fft/fft.rs | 8 +- core/src/ops/fft/mod.rs | 30 ++- core/src/ops/fft/utils.rs | 27 +++ core/src/rand/initialize.rs | 68 +++--- core/src/rand/mod.rs | 2 - core/src/traits/misc/toggle.rs | 11 +- core/tests/fft.rs | 25 --- core/tests/func.rs | 10 +- data/src/params/impls/impl_rand.rs | 8 +- models/linear/src/impls/impl_rand.rs | 141 ++++++++++-- models/linear/src/model/layer.rs | 7 + 23 files changed, 532 insertions(+), 330 deletions(-) rename core/src/func/{activate.rs => activate/mod.rs} (87%) create mode 100644 core/src/macros/builder.rs create mode 100644 core/src/macros/enums.rs create mode 100644 core/src/macros/getters.rs create mode 100644 core/src/macros/ops.rs create mode 100644 core/src/macros/toggle.rs rename core/src/ops/fft/{cmp.rs => cmp/direction.rs} (92%) create mode 100644 core/src/ops/fft/cmp/mode.rs rename core/src/ops/fft/{ => cmp}/plan.rs (61%) diff --git a/core/src/error/kinds/external.rs b/core/src/error/kinds/external.rs index 89bdc0e0..eb5d289f 100644 --- a/core/src/error/kinds/external.rs +++ b/core/src/error/kinds/external.rs @@ -67,4 +67,4 @@ impl From> for ExternalError { } } -error_from!(ExternalError::Error<&str, String>); +from_variant!(ExternalError::Error {<&str>.to_string(), .to_string()}); diff --git a/core/src/func/activate.rs b/core/src/func/activate/mod.rs similarity 
index 87% rename from core/src/func/activate.rs rename to core/src/func/activate/mod.rs index 492ed4d2..d2bfba90 100644 --- a/core/src/func/activate.rs +++ b/core/src/func/activate/mod.rs @@ -1,6 +1,6 @@ /* - Appellation: activate - Contrib: FL03 + Appellation: activate + Contrib: FL03 */ pub use self::{binary::*, nl::*}; diff --git a/core/src/lib.rs b/core/src/lib.rs index 13004e89..76340d86 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -22,6 +22,7 @@ pub use self::rand::{Initialize, InitializeExt}; pub(crate) mod macros; pub(crate) mod primitives; +#[macro_use] pub mod error; pub mod func; pub mod math; diff --git a/core/src/macros.rs b/core/src/macros.rs index df13c3b0..af6a1fec 100644 --- a/core/src/macros.rs +++ b/core/src/macros.rs @@ -2,205 +2,17 @@ Appellation: macros Contrib: FL03 */ -#![allow(unused_macros)] -macro_rules! error_from { - ($base:ident::$variant:ident<$($err:ty),* $(,)?>) => { - error_from!(@loop $base::$variant<$($err),*>); - }; - ($base:ident::$variant:ident<$err:ty>$($rest:tt)*) => { - error_from!(@loop $base::$variant<$($err),*>$($rest)*); - }; - (@loop $base:ident::$variant:ident<$($err:ty),* $(,)?>) => { - $( - error_from!(@impl $base::$variant<$err>); - )* - }; - (@impl $base:ident::$variant:ident<$err:ty>) => { - impl From<$err> for $base { - fn from(err: $err) -> Self { - Self::$variant(err.to_string()) - } - } - }; - (@impl $base:ident::$variant:ident<$err:ty>.$method:ident) => { - impl From<$err> for $base { - fn from(err: $err) -> Self { - Self::$variant(err.$method()) - } - } - }; -} - -macro_rules! nested_constructor { - ($variant:ident<$inner:ident>, $method:ident, [$($call:ident),*]) => { - nested_constructor!(@loop $variant<$inner>, $method, [$($call),*]); - }; - (@loop $variant:ident<$inner:ident>, $method:ident, [$($call:ident),*]) => { - pub fn $method(inner:$inner) -> Self { - Self::$variant(inner) - } - - $( - pub fn $call() -> Self { - Self::$method($inner::$call()) - } - )* - - }; -} - -macro_rules! 
variant_constructor { - ($($rest:tt),* $(,)?) => { - $( - variant_constructor!(@loop $($rest),*); - )* - }; - ($($variant:ident::$method:ident$(($call:expr))?),* $(,)?) => { - $( - variant_constructor!(@loop $variant::$method$(($call))?); - )* - }; - - (@loop $variant:ident::$method:ident$(($call:expr))?) => { - pub fn $method() -> Self { - Self::$variant$(($call))? - } - }; -} - -macro_rules! impl_unary { - ($name:ident::$call:ident<$T:ty>($f:expr) $($rest:tt)*) => { - impl_unary!(@impl $name::$call<$T>($f) $($rest)*); - }; - (@impl $name:ident::$call:ident<$T:ty>($f:expr)) => { - impl $name for $T { - type Output = $T; - - fn $call(&self) -> Self::Output { - $f(self) - } - } - }; -} - -macro_rules! unary { - ($($name:ident::$call:ident),* $(,)?) => { - $( - unary!(@impl $name::$call(self)); - )* - }; - ($($name:ident::$call:ident(self)),* $(,)?) => { - $( - unary!(@impl $name::$call(self)); - )* - }; - ($($name:ident::$call:ident(&self)),* $(,)?) => { - $( - unary!(@impl $name::$call(&self)); - )* - }; - (@impl $name:ident::$call:ident(self)) => { - pub trait $name { - type Output; - - fn $call(self) -> Self::Output; - } - }; - (@impl $name:ident::$call:ident(&self)) => { - pub trait $name { - type Output; - - fn $call(&self) -> Self::Output; - } - }; -} - -#[macro_export] -macro_rules! builder { - ($(#[derive($($d:ident),+)])?$name:ident::<$inner:ty> {$($k:ident: $v:ty),* $(,)?}) => { - $crate::builder!(@loop builder: $name, derive: [$($($d),+)?], inner: $inner {$($k: $v),*}); - }; - ($(#[derive($($d:ident),+)])? 
$name:ident($inner:ty) {$($k:ident: $v:ty),* $(,)?}) => { - $crate::builder!(@loop builder: $name, derive: [$($($d),+)?], inner: $inner {$($k: $v),*}); - }; - (@loop builder: $name:ident, derive: [$($d:ident),* $(,)?], inner: $inner:ty {$($k:ident: $v:ty),* $(,)?}) => { - - #[derive(Default, $($d),*)] - pub struct $name { - inner: $inner, - } - - $crate::builder!(@impl builder: $name, inner: $inner {$($k: $v),*}); - }; - (@impl builder: $name:ident, inner: $inner:ty {$($k:ident: $v:ty),* $(,)?}) => { - impl $name { - pub fn new() -> Self { - Self { - inner: Default::default() - } - } - - pub fn from_inner(inner: $inner) -> Self { - Self { inner } - } - - pub fn build(self) -> $inner { - self.inner - } - - $( - pub fn $k(mut self, $k: $v) -> Self { - self.inner.$k = $k; - self - } - )* - } - }; -} - -#[macro_export] -macro_rules! getters { - ($($call:ident$(.$field:ident)?<$out:ty>),* $(,)?) => { - $($crate::getters!(@impl $call$(.$field)?<$out>);)* - }; - ($via:ident::<[$($call:ident$(.$field:ident)?<$out:ty>),* $(,)?]>) => { - $($crate::getters!(@impl $via::$call$(.$field)?<$out>);)* - }; - ($($call:ident$(.$field:ident)?),* $(,)? => $out:ty) => { - $($crate::getters!(@impl $call$(.$field)?<$out>);)* - }; - ($via:ident::<[$($call:ident$(.$field:ident)?),* $(,)?]> => $out:ty) => { - $crate::getters!($via::<[$($call$(.$field)?<$out>),*]>); - }; - - (@impl $call:ident<$out:ty>) => { - $crate::getters!(@impl $call.$call<$out>); - }; - (@impl $via:ident::$call:ident<$out:ty>) => { - $crate::getters!(@impl $via::$call.$call<$out>); - }; - (@impl $call:ident.$field:ident<$out:ty>) => { - pub fn $call(&self) -> &$out { - &self.$field - } - paste::paste! { - pub fn [< $call _mut>](&mut self) -> &mut $out { - &mut self.$field - } - } - }; - (@impl $via:ident::$call:ident.$field:ident<$out:ty>) => { - pub fn $call(&self) -> &$out { - &self.$via.$field - } - paste::paste! 
{ - pub fn [< $call _mut>](&mut self) -> &mut $out { - &mut self.$via.$field - } - } - }; -} +#[macro_use] +mod builder; +#[macro_use] +mod enums; +#[macro_use] +mod getters; +#[macro_use] +mod ops; +#[macro_use] +mod toggle; /// AS #[macro_export] diff --git a/core/src/macros/builder.rs b/core/src/macros/builder.rs new file mode 100644 index 00000000..8fba06d2 --- /dev/null +++ b/core/src/macros/builder.rs @@ -0,0 +1,47 @@ +/* + Appellation: builder + Contrib: FL03 +*/ + +#[macro_export] +macro_rules! builder { + ($(#[derive($($d:ident),+)])?$name:ident::<$inner:ty> {$($k:ident: $v:ty),* $(,)?}) => { + builder!(@loop builder: $name, derive: [$($($d),+)?], inner: $inner {$($k: $v),*}); + }; + ($(#[derive($($d:ident),+)])? $name:ident($inner:ty) {$($k:ident: $v:ty),* $(,)?}) => { + builder!(@loop builder: $name, derive: [$($($d),+)?], inner: $inner {$($k: $v),*}); + }; + (@loop builder: $name:ident, derive: [$($d:ident),* $(,)?], inner: $inner:ty {$($k:ident: $v:ty),* $(,)?}) => { + + #[derive(Default, $($d),*)] + pub struct $name { + inner: $inner, + } + + builder!(@impl builder: $name, inner: $inner {$($k: $v),*}); + }; + (@impl builder: $name:ident, inner: $inner:ty {$($k:ident: $v:ty),* $(,)?}) => { + impl $name { + pub fn new() -> Self { + Self { + inner: Default::default() + } + } + + pub fn from_inner(inner: $inner) -> Self { + Self { inner } + } + + pub fn build(self) -> $inner { + self.inner + } + + $( + pub fn $k(mut self, $k: $v) -> Self { + self.inner.$k = $k; + self + } + )* + } + }; +} diff --git a/core/src/macros/enums.rs b/core/src/macros/enums.rs new file mode 100644 index 00000000..ce08a85a --- /dev/null +++ b/core/src/macros/enums.rs @@ -0,0 +1,44 @@ +/* + Appellation: enums + Contrib: FL03 +*/ + +macro_rules! 
from_variant { + ($base:ident::$variant:ident $($rest:tt)*) => { + from_variant!(@branch $base::$variant $($rest)*); + }; + (@branch $base:ident::$variant:ident($from:ty)$(.$method:ident())*) => { + from_variant!(@impl $base::$variant($from)$(.$method())*); + }; + (@branch $base:ident::$variant:ident{$(<$err:ty>$(.$method:ident())*),* $(,)?}) => { + $( + from_variant!(@impl $base::$variant($err)$(.$method())*); + )* + }; + (@impl $base:ident::$variant:ident($from:ty)$(.$method:ident())*) => { + impl From<$from> for $base { + fn from(val: $from) -> Self { + Self::$variant(val$(.$method())*) + } + } + }; +} + +#[allow(unused_macros)] +macro_rules! nested_enum_constructor { + ($variant:ident<$inner:ident>, $method:ident, [$($call:ident),*]) => { + nested_enum_constructor!(@loop $variant<$inner>, $method, [$($call),*]); + }; + (@loop $variant:ident<$inner:ident>, $method:ident, [$($call:ident),*]) => { + pub fn $method(inner:$inner) -> Self { + Self::$variant(inner) + } + + $( + pub fn $call() -> Self { + Self::$method($inner::$call()) + } + )* + + }; +} diff --git a/core/src/macros/getters.rs b/core/src/macros/getters.rs new file mode 100644 index 00000000..86d232b4 --- /dev/null +++ b/core/src/macros/getters.rs @@ -0,0 +1,47 @@ +/* + Appellation: getters + Contrib: FL03 +*/ + +#[macro_export] +macro_rules! getters { + ($($call:ident$(.$field:ident)?<$out:ty>),* $(,)?) => { + $($crate::getters!(@impl $call$(.$field)?<$out>);)* + }; + ($via:ident::<[$($call:ident$(.$field:ident)?<$out:ty>),* $(,)?]>) => { + $($crate::getters!(@impl $via::$call$(.$field)?<$out>);)* + }; + ($($call:ident$(.$field:ident)?),* $(,)? 
=> $out:ty) => { + $($crate::getters!(@impl $call$(.$field)?<$out>);)* + }; + ($via:ident::<[$($call:ident$(.$field:ident)?),* $(,)?]> => $out:ty) => { + $crate::getters!($via::<[$($call$(.$field)?<$out>),*]>); + }; + + (@impl $call:ident<$out:ty>) => { + $crate::getters!(@impl $call.$call<$out>); + }; + (@impl $via:ident::$call:ident<$out:ty>) => { + $crate::getters!(@impl $via::$call.$call<$out>); + }; + (@impl $call:ident.$field:ident<$out:ty>) => { + pub fn $call(&self) -> &$out { + &self.$field + } + paste::paste! { + pub fn [< $call _mut>](&mut self) -> &mut $out { + &mut self.$field + } + } + }; + (@impl $via:ident::$call:ident.$field:ident<$out:ty>) => { + pub fn $call(&self) -> &$out { + &self.$via.$field + } + paste::paste! { + pub fn [< $call _mut>](&mut self) -> &mut $out { + &mut self.$via.$field + } + } + }; +} diff --git a/core/src/macros/ops.rs b/core/src/macros/ops.rs new file mode 100644 index 00000000..f37089b0 --- /dev/null +++ b/core/src/macros/ops.rs @@ -0,0 +1,36 @@ +/* + Appellation: ops + Contrib: FL03 +*/ + +macro_rules! unary { + ($($name:ident::$call:ident),* $(,)?) => { + $( + unary!(@impl $name::$call(self)); + )* + }; + ($($name:ident::$call:ident(self)),* $(,)?) => { + $( + unary!(@impl $name::$call(self)); + )* + }; + ($($name:ident::$call:ident(&self)),* $(,)?) => { + $( + unary!(@impl $name::$call(&self)); + )* + }; + (@impl $name:ident::$call:ident(self)) => { + pub trait $name { + type Output; + + fn $call(self) -> Self::Output; + } + }; + (@impl $name:ident::$call:ident(&self)) => { + pub trait $name { + type Output; + + fn $call(&self) -> Self::Output; + } + }; +} diff --git a/core/src/macros/toggle.rs b/core/src/macros/toggle.rs new file mode 100644 index 00000000..908f6abf --- /dev/null +++ b/core/src/macros/toggle.rs @@ -0,0 +1,19 @@ +/* + Appellation: toggle + Contrib: FL03 +*/ + +#[macro_export] +macro_rules! toggle { + (enum $($name:ident),* $(,)?) 
=> { + $(toggle!(@enum $name);)* + }; + + (@enum $name:ident) => { + #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] + #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] + pub enum $name {} + + impl $crate::traits::misc::toggle::Toggle for $name {} + }; +} diff --git a/core/src/ops/fft/cmp.rs b/core/src/ops/fft/cmp/direction.rs similarity index 92% rename from core/src/ops/fft/cmp.rs rename to core/src/ops/fft/cmp/direction.rs index b22eb815..5daff00a 100644 --- a/core/src/ops/fft/cmp.rs +++ b/core/src/ops/fft/cmp/direction.rs @@ -1,7 +1,8 @@ /* - Appellation: cmp - Contrib: FL03 + Appellation: direction + Contrib: FL03 */ +use scsys::VariantConstructors; use strum::{ AsRefStr, Display, EnumCount, EnumIs, EnumIter, EnumString, VariantArray, VariantNames, }; @@ -81,6 +82,7 @@ impl From for usize { PartialEq, PartialOrd, VariantArray, + VariantConstructors, VariantNames, )] #[cfg_attr( diff --git a/core/src/ops/fft/cmp/mode.rs b/core/src/ops/fft/cmp/mode.rs new file mode 100644 index 00000000..7d538eae --- /dev/null +++ b/core/src/ops/fft/cmp/mode.rs @@ -0,0 +1,44 @@ +/* + Appellation: mode + Contrib: FL03 +*/ +use scsys::VariantConstructors; +use strum::{ + AsRefStr, Display, EnumCount, EnumIs, EnumIter, EnumString, VariantArray, VariantNames, +}; + +toggle!(enum C, R); + +/// +#[derive( + AsRefStr, + Clone, + Copy, + Debug, + Default, + Display, + EnumCount, + EnumIs, + EnumIter, + EnumString, + Eq, + Hash, + Ord, + PartialEq, + PartialOrd, + VariantArray, + VariantConstructors, + VariantNames, +)] +#[cfg_attr( + feature = "serde", + derive(serde::Deserialize, serde::Serialize), + serde(rename_all = "lowercase", untagged) +)] +#[repr(usize)] +#[strum(serialize_all = "lowercase")] +pub enum FftMode { + #[default] + Complex, + Real, +} diff --git a/core/src/ops/fft/plan.rs b/core/src/ops/fft/cmp/plan.rs similarity index 61% rename from core/src/ops/fft/plan.rs rename to core/src/ops/fft/cmp/plan.rs index 
10a82ed7..d6809acb 100644 --- a/core/src/ops/fft/plan.rs +++ b/core/src/ops/fft/cmp/plan.rs @@ -2,56 +2,62 @@ Appellation: plan Contrib: FL03 */ -#[cfg(all(feature = "alloc", no_std))] -use alloc::vec::{self, Vec}; use core::slice; -#[cfg(feature = "std")] -use std::vec; + +use crate::ops::prelude::fft_permutation; #[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct FftPlan { - n: usize, + len: usize, plan: Vec, } impl FftPlan { - pub fn new(n: usize) -> Self { - let plan = Vec::with_capacity(n); - Self { n, plan } + pub fn new(len: usize) -> Self { + Self { + len, + plan: Vec::with_capacity(len), + } } pub fn build(self) -> Self { - let mut plan = Vec::with_capacity(self.n); - plan.extend(0..self.n); - - let mut rev = 0; // reverse - let mut pos = 1; // position - while pos < self.n { - let mut bit = self.n >> 1; - while bit & rev != 0 { - rev ^= bit; - bit >>= 1; - } - rev ^= bit; - // This is equivalent to adding 1 to a reversed number - if pos < rev { - // Only swap each element once - plan.swap(pos, rev); - } - pos += 1; - } + let plan = fft_permutation(self.len); Self { plan, ..self } } pub fn clear(&mut self) { - self.n = 0; + self.len = 0; self.plan.clear(); } + pub fn get(&self, index: usize) -> Option<&usize> { + self.plan().get(index) + } + + pub fn iter(&self) -> slice::Iter { + self.plan().iter() + } + + pub fn len(&self) -> usize { + self.len + } + pub fn plan(&self) -> &[usize] { &self.plan } + + pub fn set(&mut self, len: usize) { + self.len = len; + self.plan = Vec::with_capacity(len); + } + + pub fn with(self, len: usize) -> Self { + Self { + len, + plan: Vec::with_capacity(len), + } + } } impl AsRef<[usize]> for FftPlan { @@ -76,15 +82,16 @@ impl FromIterator for FftPlan { fn from_iter>(iter: T) -> Self { let plan = Vec::from_iter(iter); Self { - n: plan.len(), + len: plan.len(), plan, } } } +#[cfg(any(feature = "alloc", feature = "std"))] 
impl IntoIterator for FftPlan { type Item = usize; - type IntoIter = vec::IntoIter; + type IntoIter = crate::rust::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { self.plan.into_iter() diff --git a/core/src/ops/fft/fft.rs b/core/src/ops/fft/fft.rs index 80e3b337..d493c24d 100644 --- a/core/src/ops/fft/fft.rs +++ b/core/src/ops/fft/fft.rs @@ -4,21 +4,21 @@ */ use super::{FftDirection, FftPlan}; -pub struct FastFourierTransform { +pub struct Fft { direction: FftDirection, plan: FftPlan, } -impl FastFourierTransform { +impl Fft { pub fn new(direction: FftDirection, plan: FftPlan) -> Self { Self { direction, plan } } - pub fn direction(&self) -> FftDirection { + pub const fn direction(&self) -> FftDirection { self.direction } - pub fn plan(&self) -> &FftPlan { + pub const fn plan(&self) -> &FftPlan { &self.plan } } diff --git a/core/src/ops/fft/mod.rs b/core/src/ops/fft/mod.rs index 150558bc..3306c3b5 100644 --- a/core/src/ops/fft/mod.rs +++ b/core/src/ops/fft/mod.rs @@ -4,27 +4,39 @@ */ //! # Fast Fourier Transform //! -//! +//! The `fft` module provides an implementation of the Fast Fourier Transform (FFT) algorithm. +//! The Fast Fourier Transform is an efficient algorithm for computing the Discrete Fourier Transform (DFT). pub use self::prelude::*; pub(crate) mod fft; pub(crate) mod utils; -pub mod cmp; -pub mod plan; +pub mod cmp { + pub use self::prelude::*; + + pub mod direction; + pub mod mode; + pub mod plan; + + pub(crate) mod prelude { + pub use super::direction::FftDirection; + pub use super::mode::FftMode; + pub use super::plan::FftPlan; + } +} -pub trait Fft { - type Data; +/// Trait for computing the Discrete Fourier Transform (DFT) of a sequence. 
+pub trait DFT { + type Output; - fn rfft(&self) -> Self; + fn dft(&self) -> Self::Output; } pub(crate) mod prelude { - pub use super::cmp::*; + pub use super::cmp::prelude::*; pub use super::fft::*; - pub use super::plan::*; pub use super::utils::*; - pub use super::Fft; + pub use super::DFT; } #[cfg(test)] diff --git a/core/src/ops/fft/utils.rs b/core/src/ops/fft/utils.rs index fd1243f4..c3990202 100644 --- a/core/src/ops/fft/utils.rs +++ b/core/src/ops/fft/utils.rs @@ -169,3 +169,30 @@ where let scale = T::from(n).unwrap().recip(); result.iter().map(|x| x.re() * scale).collect() } + +#[doc(hidden)] +/// Generates a permutation for the Fast Fourier Transform. +pub fn fft_permutation(length: usize) -> Vec { + let mut result = Vec::new(); + result.reserve_exact(length); + for i in 0..length { + result.push(i); + } + let mut reverse = 0_usize; + let mut position = 1_usize; + while position < length { + let mut bit = length >> 1; + while bit & reverse != 0 { + reverse ^= bit; + bit >>= 1; + } + reverse ^= bit; + // This is equivalent to adding 1 to a reversed number + if position < reverse { + // Only swap each element once + result.swap(position, reverse); + } + position += 1; + } + result +} diff --git a/core/src/rand/initialize.rs b/core/src/rand/initialize.rs index d2327111..db72acf0 100644 --- a/core/src/rand/initialize.rs +++ b/core/src/rand/initialize.rs @@ -13,45 +13,46 @@ use rand_distr::{Bernoulli, BernoulliError, Distribution, StandardNormal}; /// This trait provides the base methods required for initializing an [ndarray](ndarray::ArrayBase) with random values. /// [Initialize] is similar to [RandomExt](ndarray_rand::RandomExt), however, it focuses on flexibility while implementing additional /// features geared towards machine-learning models; such as lecun_normal initialization. 
-pub trait Initialize +pub trait Initialize where D: Dimension, - S: RawData, { + type Data: RawData; /// Generate a random array using the given distribution - fn rand(shape: Sh, distr: Ds) -> ArrayBase + fn rand(shape: Sh, distr: Ds) -> Self where - S: DataOwned, - Ds: Distribution, - Sh: ShapeBuilder; + Ds: Clone + Distribution, + Sh: ShapeBuilder, + Self::Data: DataOwned; /// Generate a random array using the given distribution and random number generator - fn rand_with(shape: Sh, distr: Ds, rng: &mut R) -> ArrayBase + fn rand_with(shape: Sh, distr: Ds, rng: &mut R) -> Self where R: Rng + ?Sized, - S: DataOwned, - Ds: Distribution, - Sh: ShapeBuilder; + Ds: Clone + Distribution, + Sh: ShapeBuilder, + Self::Data: DataOwned; /// Initialize an array with random values using the given distribution and current shape - fn init_rand(self, distr: Ds) -> ArrayBase + fn init_rand(self, distr: Ds) -> Self where - S: DataOwned, - Ds: Distribution, - Self: Sized; + Ds: Clone + Distribution, + Self: Sized, + Self::Data: DataOwned; /// Initialize an array with random values from the current shape using the given distribution and random number generator - fn init_rand_with(self, distr: Ds, rng: &mut R) -> ArrayBase + fn init_rand_with(self, distr: Ds, rng: &mut R) -> Self where R: Rng + ?Sized, - S: DataOwned, - Ds: Distribution; + Ds: Clone + Distribution, + Self::Data: DataOwned; } /// This trait extends the [Initialize] trait with methods for generating random arrays from various distributions. 
-pub trait InitializeExt: Initialize +pub trait InitializeExt: Initialize + Sized where + A: Clone, D: Dimension, S: RawData, { - fn bernoulli(shape: Sh, p: Option) -> Result, BernoulliError> + fn bernoulli(shape: Sh, p: Option) -> Result where S: DataOwned, Sh: ShapeBuilder, @@ -61,7 +62,7 @@ where Ok(Self::rand(shape, dist)) } /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution - fn stdnorm(shape: Sh) -> ArrayBase + fn stdnorm(shape: Sh) -> Self where S: DataOwned, Sh: ShapeBuilder, @@ -70,7 +71,7 @@ where Self::rand(shape, StandardNormal) } /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution with a given seed - fn stdnorm_from_seed(shape: Sh, seed: u64) -> ArrayBase + fn stdnorm_from_seed(shape: Sh, seed: u64) -> Self where S: DataOwned, Sh: ShapeBuilder, @@ -83,31 +84,34 @@ where ) } /// A [uniform](rand_distr::uniform::Uniform) generator with values between u(-dk, dk) - fn uniform(shape: Sh, dk: A) -> ArrayBase + fn uniform(shape: Sh, dk: A) -> Self where - A: Clone + Neg + SampleUniform, + A: Neg + SampleUniform, S: DataOwned, Sh: ShapeBuilder, + ::Sampler: Clone, { Self::rand(shape, Uniform::new(dk.clone().neg(), dk)) } /// Generate a random array with values between u(-a, a) where a is the reciprocal of the value at the given axis - fn uniform_along(shape: Sh, axis: usize) -> ArrayBase + fn uniform_along(shape: Sh, axis: usize) -> Self where A: Copy + Float + SampleUniform, S: DataOwned, Sh: ShapeBuilder, + ::Sampler: Clone, { let dim = shape.into_shape().raw_dim().clone(); let dk = A::from(dim[axis]).unwrap().recip(); Self::uniform(dim, dk) } /// A [uniform](rand_distr::uniform::Uniform) generator with values between u(-dk, dk) - fn uniform_between(shape: Sh, a: A, b: A) -> ArrayBase + fn uniform_between(shape: Sh, a: A, b: A) -> Self where A: SampleUniform, S: DataOwned, Sh: ShapeBuilder, + ::Sampler: Clone, { Self::rand(shape, Uniform::new(a, b)) } @@ -115,16 +119,17 
@@ where /* ************ Implementations ************ */ -impl Initialize for ArrayBase +impl Initialize for ArrayBase where D: Dimension, S: RawData, ArrayBase: RandomExt, { + type Data = S; fn rand(shape: Sh, distr: Ds) -> ArrayBase where S: DataOwned, - Ds: Distribution, + Ds: Clone + Distribution, Sh: ShapeBuilder, { Self::random(shape, distr) @@ -134,7 +139,7 @@ where where R: Rng + ?Sized, S: DataOwned, - Ds: Distribution, + Ds: Clone + Distribution, Sh: ShapeBuilder, { Self::random_using(shape, distr, rng) @@ -143,7 +148,7 @@ where fn init_rand(self, distr: Ds) -> ArrayBase where S: DataOwned, - Ds: Distribution, + Ds: Clone + Distribution, { Self::rand(self.dim(), distr) } @@ -152,7 +157,7 @@ where where R: Rng + ?Sized, S: DataOwned, - Ds: Distribution, + Ds: Clone + Distribution, { Self::rand_with(self.dim(), distr, rng) } @@ -160,8 +165,9 @@ where impl InitializeExt for U where + A: Clone, D: Dimension, S: RawData, - U: Initialize, + U: Initialize, { } diff --git a/core/src/rand/mod.rs b/core/src/rand/mod.rs index 58bc5810..6cb1c6db 100644 --- a/core/src/rand/mod.rs +++ b/core/src/rand/mod.rs @@ -17,8 +17,6 @@ pub mod gen { #[doc(no_inline)] pub use ndarray_rand as ndrand; #[doc(no_inline)] -pub use ndrand::{RandomExt, SamplingStrategy}; -#[doc(no_inline)] pub use rand; #[doc(no_inline)] pub use rand_distr; diff --git a/core/src/traits/misc/toggle.rs b/core/src/traits/misc/toggle.rs index 0bca252e..4865d33f 100644 --- a/core/src/traits/misc/toggle.rs +++ b/core/src/traits/misc/toggle.rs @@ -5,18 +5,21 @@ pub trait Toggle: 'static {} -pub trait Mode: Toggle { - fn of() -> bool +pub trait OfType { + fn of() -> bool where - K: Toggle, + T: 'static, + Self: 'static, { - core::any::TypeId::of::() == core::any::TypeId::of::() + core::any::TypeId::of::() == core::any::TypeId::of::() } } /* ************* Implementations ************* */ +impl OfType for T {} + macro_rules! impl_toggle { ($($scope:ident$(<$T:ident>)?),* $(,)?) 
=> { $(impl_toggle!(@impl $scope$(<$T>)?);)* diff --git a/core/tests/fft.rs b/core/tests/fft.rs index 279e28e7..61fff93b 100644 --- a/core/tests/fft.rs +++ b/core/tests/fft.rs @@ -12,31 +12,6 @@ use num::traits::Float; const EPSILON: f64 = 1e-6; -fn fft_permutation(length: usize) -> Vec { - let mut result = Vec::new(); - result.reserve_exact(length); - for i in 0..length { - result.push(i); - } - let mut reverse = 0_usize; - let mut position = 1_usize; - while position < length { - let mut bit = length >> 1; - while bit & reverse != 0 { - reverse ^= bit; - bit >>= 1; - } - reverse ^= bit; - // This is equivalent to adding 1 to a reversed number - if position < reverse { - // Only swap each element once - result.swap(position, reverse); - } - position += 1; - } - result -} - lazy_static! { static ref EXPECTED_RFFT: Vec> = vec![ Complex { re: 28.0, im: 0.0 }, diff --git a/core/tests/func.rs b/core/tests/func.rs index 2f888a90..e1a5ccef 100644 --- a/core/tests/func.rs +++ b/core/tests/func.rs @@ -8,9 +8,11 @@ use ndarray::prelude::*; #[test] #[cfg(feature = "rand")] fn test_dropout() { - let arr = Array2::::ones((2, 2)); - assert!(arr.iter().all(|&x| x == 1.0)); + let shape = (512, 2048); + let arr = Array2::::ones(shape); let dropout = Dropout::new(0.5); - let res = dropout.forward(&arr); - assert!(res.iter().any(|&x| x == 0.0)); + let out = dropout.forward(&arr); + + assert!(arr.iter().all(|&x| x == 1.0)); + assert!(out.iter().any(|&x| x == 0.0)); } diff --git a/data/src/params/impls/impl_rand.rs b/data/src/params/impls/impl_rand.rs index 4293d8a1..87b57347 100644 --- a/data/src/params/impls/impl_rand.rs +++ b/data/src/params/impls/impl_rand.rs @@ -12,10 +12,14 @@ use num::Float; impl Parameter where D: Dimension, - T: Float + SampleUniform, + T: Float, StandardNormal: Distribution, { - pub fn init_uniform(mut self, dk: T) -> Self { + pub fn init_uniform(mut self, dk: T) -> Self + where + T: SampleUniform, + ::Sampler: Clone, + { let dim = self.value.dim(); 
self.value = Array::uniform(dim, dk); self diff --git a/models/linear/src/impls/impl_rand.rs b/models/linear/src/impls/impl_rand.rs index d571afd3..f79e7656 100644 --- a/models/linear/src/impls/impl_rand.rs +++ b/models/linear/src/impls/impl_rand.rs @@ -4,41 +4,40 @@ */ #![cfg(feature = "rand")] -use crate::params::ParamsBase; +use crate::params::{ParamMode, ParamsBase}; use crate::{bias_dim, Linear}; -use concision::prelude::InitializeExt; -use concision::rand::rand_distr::{uniform, Distribution, StandardNormal}; +use concision::rand::rand::Rng; +use concision::rand::rand_distr::{uniform::SampleUniform, Distribution, StandardNormal}; +use concision::{Initialize, InitializeExt}; use nd::*; use num::Float; impl Linear where - A: Float + uniform::SampleUniform, + A: Clone + Float, D: RemoveAxis, - K: 'static, + K: ParamMode, StandardNormal: Distribution, { - pub fn uniform(self) -> Self { + pub fn uniform(self) -> Self + where + A: SampleUniform, + ::Sampler: Clone, + { Self { params: self.params.uniform(), ..self } } - - pub fn uniform_between(self, low: A, high: A) -> Self { - Self { - params: self.params.uniform_between(low, high), - ..self - } - } } -impl ParamsBase, D, K> +impl crate::LinearParams where - A: Float + uniform::SampleUniform, + A: Clone + Float + SampleUniform, D: RemoveAxis, - K: 'static, + K: ParamMode, StandardNormal: Distribution, + ::Sampler: Clone, { /// Computes the reciprocal of the input features. 
pub(crate) fn dk(&self) -> A { @@ -80,3 +79,113 @@ where } } } + +impl Initialize for Linear +where + D: RemoveAxis, + K: ParamMode, + StandardNormal: Distribution, +{ + type Data = OwnedRepr; + fn rand(shape: Sh, distr: Ds) -> Self + where + Sh: ShapeBuilder, + Ds: Clone + Distribution, + { + Self::from_params(ParamsBase::rand(shape, distr)) + } + + fn rand_with(shape: Sh, distr: Ds, rng: &mut R) -> Self + where + R: Rng + ?Sized, + Ds: Clone + Distribution, + Sh: ShapeBuilder, + { + Self::from_params(ParamsBase::rand_with(shape, distr, rng)) + } + + fn init_rand(self, distr: Ds) -> Self + where + Ds: Clone + Distribution, + Self: Sized, + { + Self::rand(self.dim(), distr) + } + + fn init_rand_with(self, distr: Ds, rng: &mut R) -> Self + where + R: Rng + ?Sized, + Ds: Clone + Distribution, + { + Self::rand_with(self.dim(), distr, rng) + } +} + +impl Initialize for ParamsBase +where + D: RemoveAxis, + K: ParamMode, + S: DataOwned, + StandardNormal: Distribution, +{ + type Data = S; + fn rand(shape: Sh, distr: Dstr) -> Self + where + Sh: ShapeBuilder, + Dstr: Clone + Distribution, + { + let dim = shape.into_shape().raw_dim().clone(); + let bias = if K::BIASED { + Some(ArrayBase::rand(bias_dim(dim.clone()), distr.clone())) + } else { + None + }; + Self { + weight: ArrayBase::rand(dim, distr), + bias, + _mode: core::marker::PhantomData::, + } + } + + fn rand_with(shape: Sh, distr: Ds, rng: &mut R) -> Self + where + R: Rng + ?Sized, + S: DataOwned, + Ds: Clone + Distribution, + Sh: ShapeBuilder, + { + let dim = shape.into_shape().raw_dim().clone(); + let bias = if K::BIASED { + Some(ArrayBase::rand_with( + bias_dim(dim.clone()), + distr.clone(), + rng, + )) + } else { + None + }; + Self { + weight: ArrayBase::rand_with(dim, distr, rng), + bias, + _mode: core::marker::PhantomData::, + } + } + + fn init_rand(self, distr: Ds) -> Self + where + S: DataOwned, + Ds: Clone + Distribution, + Self: Sized, + { + Self::rand(self.dim(), distr) + } + + fn init_rand_with(self, 
distr: Ds, rng: &mut R) -> Self + where + R: Rng + ?Sized, + S: DataOwned, + Ds: Clone + Distribution, + { + Self::rand_with(self.dim(), distr, rng) + } +} diff --git a/models/linear/src/model/layer.rs b/models/linear/src/model/layer.rs index 0e6a5a79..930d0671 100644 --- a/models/linear/src/model/layer.rs +++ b/models/linear/src/model/layer.rs @@ -48,6 +48,11 @@ where Self { config, params } } + pub(crate) fn from_params(params: LinearParams) -> Self { + let config = Config::::new().with_shape(params.raw_dim()); + Self { config, params } + } + /// Applies an activcation function onto the prediction of the model. pub fn activate(&self, args: &X, func: F) -> Result where @@ -120,6 +125,8 @@ where ..self } } + + concision::dimensional!(params()); } impl Linear From 66c782406f83139c3c73f145a79a6c6efbf69c15 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sat, 18 May 2024 15:23:21 -0500 Subject: [PATCH 20/23] update Signed-off-by: Joe McCain III --- .github/workflows/clippy.yml | 2 +- .github/workflows/crates.yml | 2 +- .github/workflows/rust.yml | 2 +- core/src/{rand => init}/gen/lecun.rs | 0 core/src/{rand => init}/initialize.rs | 0 core/src/{rand => init}/mod.rs | 0 core/src/{rand => init}/utils.rs | 0 core/src/lib.rs | 11 +- core/src/types/mod.rs | 3 +- core/tests/random.rs | 2 +- data/src/params/impls/impl_rand.rs | 4 +- models/linear/src/impls/impl_rand.rs | 4 +- models/linear/tests/linear.rs | 26 ++-- models/transformers/src/attention/head.rs | 62 +++++---- models/transformers/src/attention/mod.rs | 7 +- models/transformers/src/impls/impl_head.rs | 30 ++++- models/transformers/src/impls/impl_linalg.rs | 12 +- models/transformers/src/impls/impl_params.rs | 12 +- models/transformers/src/lib.rs | 1 + models/transformers/src/macros.rs | 14 +- models/transformers/src/ops/merge.rs | 52 ++++++++ models/transformers/src/ops/mod.rs | 130 +++++++++++++++++++ models/transformers/src/ops/split.rs | 50 +++++++ models/transformers/src/params/mod.rs | 6 +- 
models/transformers/src/params/store.rs | 4 +- 25 files changed, 354 insertions(+), 82 deletions(-) rename core/src/{rand => init}/gen/lecun.rs (100%) rename core/src/{rand => init}/initialize.rs (100%) rename core/src/{rand => init}/mod.rs (100%) rename core/src/{rand => init}/utils.rs (100%) create mode 100644 models/transformers/src/ops/merge.rs create mode 100644 models/transformers/src/ops/mod.rs create mode 100644 models/transformers/src/ops/split.rs diff --git a/.github/workflows/clippy.yml b/.github/workflows/clippy.yml index 8ad78abc..39d91bd1 100644 --- a/.github/workflows/clippy.yml +++ b/.github/workflows/clippy.yml @@ -1,4 +1,4 @@ -name: Clippy +name: clippy on: pull_request: diff --git a/.github/workflows/crates.yml b/.github/workflows/crates.yml index 467e3a1c..b7c226d0 100644 --- a/.github/workflows/crates.yml +++ b/.github/workflows/crates.yml @@ -1,4 +1,4 @@ -name: crates.io +name: crates concurrency: group: ${{ github.workflow }}-${{ github.ref }} diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 2ed6c641..d8afa861 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -1,4 +1,4 @@ -name: Rust +name: rust concurrency: cancel-in-progress: false diff --git a/core/src/rand/gen/lecun.rs b/core/src/init/gen/lecun.rs similarity index 100% rename from core/src/rand/gen/lecun.rs rename to core/src/init/gen/lecun.rs diff --git a/core/src/rand/initialize.rs b/core/src/init/initialize.rs similarity index 100% rename from core/src/rand/initialize.rs rename to core/src/init/initialize.rs diff --git a/core/src/rand/mod.rs b/core/src/init/mod.rs similarity index 100% rename from core/src/rand/mod.rs rename to core/src/init/mod.rs diff --git a/core/src/rand/utils.rs b/core/src/init/utils.rs similarity index 100% rename from core/src/rand/utils.rs rename to core/src/init/utils.rs diff --git a/core/src/lib.rs b/core/src/lib.rs index 76340d86..0c6d4ec8 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -16,7 +16,7 @@ 
pub use self::nn::Module; pub use self::{primitives::*, traits::prelude::*, types::prelude::*, utils::prelude::*}; #[cfg(feature = "rand")] -pub use self::rand::{Initialize, InitializeExt}; +pub use self::init::{Initialize, InitializeExt}; #[macro_use] pub(crate) mod macros; @@ -25,11 +25,12 @@ pub(crate) mod primitives; #[macro_use] pub mod error; pub mod func; +#[cfg(feature = "rand")] +pub mod init; pub mod math; pub mod nn; pub mod ops; -#[cfg(feature = "rand")] -pub mod rand; + pub mod traits; pub mod types; pub mod utils; @@ -39,12 +40,12 @@ pub mod prelude { pub use super::error::prelude::*; pub use super::func::prelude::*; + #[cfg(feature = "rand")] + pub use super::init::prelude::*; pub use super::math::prelude::*; pub use super::nn::prelude::*; pub use super::ops::prelude::*; pub use super::primitives::*; - #[cfg(feature = "rand")] - pub use super::rand::prelude::*; pub use super::traits::prelude::*; pub use super::types::prelude::*; pub use super::utils::prelude::*; diff --git a/core/src/types/mod.rs b/core/src/types/mod.rs index 926a1387..a639d5a4 100644 --- a/core/src/types/mod.rs +++ b/core/src/types/mod.rs @@ -8,6 +8,7 @@ pub use self::std_types::*; pub mod propagate; +pub type NdResult = core::result::Result; /// A type alias for a [Result](core::result::Result) with the crate's [Error](crate::error::Error) type. 
/// Defaults to `Result<(), Error>` pub type Result = core::result::Result; @@ -24,7 +25,7 @@ pub(crate) mod prelude { pub use super::propagate::Propagate; #[cfg(feature = "std")] pub use super::std_types::*; - pub use super::Result; + pub use super::{NdResult, Result}; } #[cfg(test)] diff --git a/core/tests/random.rs b/core/tests/random.rs index 1e7343a4..daa76435 100644 --- a/core/tests/random.rs +++ b/core/tests/random.rs @@ -4,7 +4,7 @@ */ extern crate concision_core as cnc; -use cnc::rand::InitializeExt; +use cnc::init::InitializeExt; use ndarray::prelude::*; #[test] diff --git a/data/src/params/impls/impl_rand.rs b/data/src/params/impls/impl_rand.rs index 87b57347..d5f066cc 100644 --- a/data/src/params/impls/impl_rand.rs +++ b/data/src/params/impls/impl_rand.rs @@ -3,8 +3,8 @@ Contrib: FL03 */ use crate::params::Parameter; -use concision::rand::rand_distr::uniform::SampleUniform; -use concision::rand::rand_distr::{Distribution, StandardNormal}; +use concision::init::rand_distr::uniform::SampleUniform; +use concision::init::rand_distr::{Distribution, StandardNormal}; use concision::InitializeExt; use ndarray::{Array, Dimension}; use num::Float; diff --git a/models/linear/src/impls/impl_rand.rs b/models/linear/src/impls/impl_rand.rs index f79e7656..f2e602e1 100644 --- a/models/linear/src/impls/impl_rand.rs +++ b/models/linear/src/impls/impl_rand.rs @@ -6,8 +6,8 @@ use crate::params::{ParamMode, ParamsBase}; use crate::{bias_dim, Linear}; -use concision::rand::rand::Rng; -use concision::rand::rand_distr::{uniform::SampleUniform, Distribution, StandardNormal}; +use concision::init::rand::Rng; +use concision::init::rand_distr::{uniform::SampleUniform, Distribution, StandardNormal}; use concision::{Initialize, InitializeExt}; use nd::*; use num::Float; diff --git a/models/linear/tests/linear.rs b/models/linear/tests/linear.rs index 165a75b2..6155e332 100644 --- a/models/linear/tests/linear.rs +++ b/models/linear/tests/linear.rs @@ -29,6 +29,21 @@ fn test_config() { 
assert!(!config.is_biased()); } +#[test] +fn test_model_toggle() { + let (_samples, (outputs, inputs)) = SHAPE; + + let model = Linear::::from_features(inputs, outputs); + assert!(model.is_biased()); + + let model = Linear::::from_features(inputs, outputs); + assert!(!model.is_biased()); + + let model = Linear::::from_features(inputs, outputs).into_unbiased(); + assert!(!model.is_biased()); +} + + #[test] #[cfg(feature = "rand")] fn test_linear() { @@ -42,14 +57,3 @@ fn test_linear() { assert_eq!(y.shape(), &[samples, outputs]); } -#[test] -#[cfg(feature = "rand")] -fn test_model_modes() { - let (_samples, (outputs, inputs)) = SHAPE; - - let model = Linear::::from_features(inputs, outputs).uniform(); - assert!(model.is_biased()); - - let model = Linear::::from_features(inputs, outputs).uniform(); - assert!(!model.is_biased()); -} diff --git a/models/transformers/src/attention/head.rs b/models/transformers/src/attention/head.rs index 8eca2693..10e0e863 100644 --- a/models/transformers/src/attention/head.rs +++ b/models/transformers/src/attention/head.rs @@ -2,28 +2,29 @@ Appellation: head Contrib: FL03 */ -use crate::params::ParamsBase; +use crate::params::QkvBase; use concision::getters; +use core::borrow::{Borrow, BorrowMut}; use nd::linalg::Dot; use nd::*; use num::complex::ComplexFloat; -use num::traits::FromPrimitive; -pub struct AttentionHead, D = Ix2> +pub struct AttentionHead> where D: Dimension, S: RawData, { - params: ParamsBase, + pub(crate) mask: Option>, + pub(crate) params: QkvBase, } -impl AttentionHead +impl AttentionHead where D: Dimension, S: RawData, { - pub fn from_params(params: ParamsBase) -> Self { - Self { params } + pub fn from_params(params: QkvBase) -> Self { + Self { mask: None, params } } pub fn builder(shape: Sh, builder: F) -> Self @@ -31,7 +32,7 @@ where F: Fn(D) -> ArrayBase, Sh: ShapeBuilder, { - Self::from_params(ParamsBase::builder(shape, builder)) + Self::from_params(QkvBase::builder(shape, builder)) } pub fn from_elem(shape: Sh, 
value: A) -> Self @@ -40,29 +41,32 @@ where A: Clone, S: DataOwned, { - Self::from_params(ParamsBase::from_elem(shape, value)) + Self::from_params(QkvBase::from_elem(shape, value)) } - #[allow(dead_code)] - pub(crate) fn dk(&self) -> A + + pub fn attention(&self) -> Array where - A: FromPrimitive, + A: ComplexFloat + ScalarOperand, + S: Data, + ArrayBase: for<'a> Dot, Output = Array>, + Array: Dot, Output = Array>, { - A::from_usize(self.k().len_of(Axis(1))).unwrap() + let (q, k, v) = self.qkv(); + crate::attention::scaled_dot_product_attention(q, k, v) } - /// Returns an immuable reference to the underlying parameters. - pub const fn params(&self) -> &ParamsBase { + pub const fn params(&self) -> &QkvBase { &self.params } /// Returns a mutable reference to the underlying parameters. - pub fn params_mut(&mut self) -> &mut ParamsBase { + pub fn params_mut(&mut self) -> &mut QkvBase { &mut self.params } - + /// Returns a three-tuple consisting of immputable references to the query, key, and value matrices respectively. pub fn qkv(&self) -> (&ArrayBase, &ArrayBase, &ArrayBase) { self.params().qkv() } - + /// Consumes the head, returning a three-tuple consisting of mutable references to the query, key, and value matrices respectively. 
pub fn into_qkv(self) -> (ArrayBase, ArrayBase, ArrayBase) { self.params.into_qkv() } @@ -73,16 +77,22 @@ where ndbuilder!(zeros() where A: Clone + num::Zero, S: DataOwned); } -impl AttentionHead, D> +impl Borrow> for AttentionHead where D: Dimension, + S: RawData, { - pub fn attention(&self) -> Array - where - A: ComplexFloat + ScalarOperand, - Array: Dot, Output = Array>, - { - let (q, k, v) = self.qkv(); - crate::attention::scaled_dot_product_attention(q, k, v) + fn borrow(&self) -> &QkvBase { + self.params() + } +} + +impl BorrowMut> for AttentionHead +where + D: Dimension, + S: RawData, +{ + fn borrow_mut(&mut self) -> &mut QkvBase { + self.params_mut() } } diff --git a/models/transformers/src/attention/mod.rs b/models/transformers/src/attention/mod.rs index e317394d..2b7e1060 100644 --- a/models/transformers/src/attention/mod.rs +++ b/models/transformers/src/attention/mod.rs @@ -17,7 +17,7 @@ pub(crate) mod prelude { pub(crate) mod utils { use concision::func::activate::Softmax; use nd::linalg::Dot; - use nd::prelude::{Array, ArrayBase, Axis, Dimension}; + use nd::prelude::{Array, ArrayBase, ArrayView, Axis, Dimension}; use nd::{Data, ScalarOperand}; use num::complex::ComplexFloat; @@ -28,6 +28,7 @@ pub(crate) mod utils { A::from(dk).unwrap().sqrt().recip() } + /// Scaled dot-product attention; pub fn scaled_dot_product_attention( q: &ArrayBase, k: &ArrayBase, @@ -37,10 +38,10 @@ pub(crate) mod utils { A: ComplexFloat + ScalarOperand, S: Data, D: Dimension, - ArrayBase: Dot, Output = Array>, + ArrayBase: for<'a> Dot, Output = Array>, Array: Dot, Output = Array>, { let dk = scale::(k.len_of(Axis(1))); - (q.dot(&k.t().to_owned()) * dk).softmax().dot(&v) + (q.dot(&k.t()) * dk).softmax().dot(&v) } } diff --git a/models/transformers/src/impls/impl_head.rs b/models/transformers/src/impls/impl_head.rs index ec8e410f..fa22f80a 100644 --- a/models/transformers/src/impls/impl_head.rs +++ b/models/transformers/src/impls/impl_head.rs @@ -3,17 +3,39 @@ Contrib: FL03 */ 
use crate::attention::AttentionHead; -use crate::params::ParamsBase; +use crate::params::QkvBase; use nd::prelude::*; -use nd::DataOwned; +use nd::{DataOwned, RawDataClone}; -impl Default for AttentionHead +impl Clone for AttentionHead +where + A: Copy, + D: Dimension, + S: RawDataClone, +{ + fn clone(&self) -> Self { + Self { + mask: self.mask.clone(), + params: self.params.clone(), + } + } +} + +impl Copy for AttentionHead +where + A: Copy, + D: Copy + Dimension, + S: Copy + RawDataClone, +{ +} + +impl Default for AttentionHead where A: Default, D: Dimension, S: DataOwned, { fn default() -> Self { - Self::from_params(ParamsBase::default()) + Self::from_params(QkvBase::default()) } } diff --git a/models/transformers/src/impls/impl_linalg.rs b/models/transformers/src/impls/impl_linalg.rs index a50ebe63..ce069afe 100644 --- a/models/transformers/src/impls/impl_linalg.rs +++ b/models/transformers/src/impls/impl_linalg.rs @@ -2,12 +2,12 @@ Appellation: impl_linalg Contrib: FL03 */ -use crate::params::{Params, ParamsBase}; +use crate::params::{Params, QkvBase}; use concision::Matmul; use nd::linalg::Dot; use nd::*; -impl Matmul> for ParamsBase +impl Matmul> for QkvBase where A: LinalgScalar, D: Dimension, @@ -19,8 +19,8 @@ where { type Output = Params; - fn matmul(&self, rhs: &ParamsBase) -> Self::Output { - ParamsBase { + fn matmul(&self, rhs: &QkvBase) -> Self::Output { + QkvBase { q: self.q().dot(rhs.q()), k: self.k().dot(rhs.k()), v: self.v().dot(rhs.v()), @@ -28,7 +28,7 @@ where } } -impl Matmul> for ParamsBase +impl Matmul> for QkvBase where A: LinalgScalar, D: Dimension, @@ -41,7 +41,7 @@ where type Output = Params; fn matmul(&self, rhs: &ArrayBase) -> Self::Output { - ParamsBase { + QkvBase { q: self.q().dot(rhs), k: self.k().dot(rhs), v: self.v().dot(rhs), diff --git a/models/transformers/src/impls/impl_params.rs b/models/transformers/src/impls/impl_params.rs index 47c9074a..9736c1b0 100644 --- a/models/transformers/src/impls/impl_params.rs +++ 
b/models/transformers/src/impls/impl_params.rs @@ -2,11 +2,11 @@ Appellation: impl_params Contrib: FL03 */ -use crate::params::ParamsBase; +use crate::params::QkvBase; use nd::prelude::*; use nd::{Data, DataOwned, RawDataClone}; -impl Clone for ParamsBase +impl Clone for QkvBase where D: Dimension, S: RawDataClone, @@ -20,14 +20,14 @@ where } } -impl Copy for ParamsBase +impl Copy for QkvBase where D: Copy + Dimension, S: Copy + RawDataClone, { } -impl Default for ParamsBase +impl Default for QkvBase where D: Dimension, S: DataOwned, @@ -42,7 +42,7 @@ where } } -impl PartialEq for ParamsBase +impl PartialEq for QkvBase where A: PartialEq, D: Dimension, @@ -53,7 +53,7 @@ where } } -impl PartialEq> for ParamsBase +impl PartialEq> for QkvBase where A: PartialEq, B: PartialEq, diff --git a/models/transformers/src/lib.rs b/models/transformers/src/lib.rs index af17dd30..7a40e9c8 100644 --- a/models/transformers/src/lib.rs +++ b/models/transformers/src/lib.rs @@ -26,6 +26,7 @@ pub(crate) mod transformer; pub mod attention; pub mod codec; +pub mod ops; pub mod params; pub(crate) mod impls { diff --git a/models/transformers/src/macros.rs b/models/transformers/src/macros.rs index 95523baf..fd05142e 100644 --- a/models/transformers/src/macros.rs +++ b/models/transformers/src/macros.rs @@ -26,7 +26,7 @@ macro_rules! ndview { ndview!(@impl $method.$call::$($rest)*); }; (@impl $method:ident.$call:ident::<$view:ident>(self) where $($rest:tt)*) => { - pub fn $method(self) -> $crate::params::ParamsBase<$view, D> + pub fn $method(self) -> $crate::params::QkvBase<$view, D> where $($rest)* { @@ -34,7 +34,7 @@ macro_rules! ndview { } }; (@impl $method:ident.$call:ident::<$view:ident>(mut self) where $($rest:tt)*) => { - pub fn $method(mut self) -> $crate::params::ParamsBase<$view, D> + pub fn $method(mut self) -> $crate::params::QkvBase<$view, D> where $($rest)* { @@ -42,7 +42,7 @@ macro_rules! 
ndview { } }; (@impl $method:ident.$call:ident::<$view:ident>(&self) where $($rest:tt)*) => { - pub fn $method(&self) -> $crate::params::ParamsBase<$view, D> + pub fn $method(&self) -> $crate::params::QkvBase<$view, D> where $($rest)* { @@ -50,7 +50,7 @@ macro_rules! ndview { } }; (@impl $method:ident.$call:ident::<$view:ident>(&mut self) where $($rest:tt)*) => { - pub fn $method(&mut self) -> $crate::params::ParamsBase<$view, D> + pub fn $method(&mut self) -> $crate::params::QkvBase<$view, D> where $($rest)* { @@ -58,7 +58,7 @@ macro_rules! ndview { } }; (@impl $method:ident.$call:ident::<'a, $view:ident>(&self) where $($rest:tt)*) => { - pub fn $method(&self) -> $crate::params::ParamsBase<$view<&'_ A>, D> + pub fn $method(&self) -> $crate::params::QkvBase<$view<&'_ A>, D> where $($rest)* { @@ -66,7 +66,7 @@ macro_rules! ndview { } }; (@impl $method:ident.$call:ident::<'a, $view:ident>(&mut self) where $($rest:tt)*) => { - pub fn $method(&mut self) -> $crate::params::ParamsBase<$view<&'_ mut A>, D> + pub fn $method(&mut self) -> $crate::params::QkvBase<$view<&'_ mut A>, D> where $($rest)* { @@ -74,7 +74,7 @@ macro_rules! 
ndview { } }; (@apply $call:ident($self:expr)) => { - $crate::params::ParamsBase { + $crate::params::QkvBase { q: $self.q.$call(), k: $self.k.$call(), v: $self.v.$call(), diff --git a/models/transformers/src/ops/merge.rs b/models/transformers/src/ops/merge.rs new file mode 100644 index 00000000..f7ada237 --- /dev/null +++ b/models/transformers/src/ops/merge.rs @@ -0,0 +1,52 @@ +/* + Appellation: merge + Contrib: FL03 +*/ +use concision::NdResult; +use nd::prelude::*; +use nd::{Data, Order}; + +pub trait Merge { + type Output; + + fn merge(self) -> Self::Output; +} + +/* + ************* Implementations ************* +*/ +impl Merge for ArrayBase +where + A: Clone, + S: Data, +{ + type Output = NdResult>; + + fn merge(self) -> Self::Output { + let (heads, seq, query) = self.dim(); + let mut tmp = self; + // swap the head and sequence axes + tmp.swap_axes(0, 1); + // reshape the qkv matrix into a 2d array + let res = tmp.to_shape(((seq, heads * query), Order::ColumnMajor))?; + Ok(res.to_owned()) + } +} + +impl Merge for ArrayBase +where + A: Clone, + S: Data, +{ + type Output = NdResult>; + + fn merge(self) -> Self::Output { + let (batch, heads, seq, query) = self.dim(); + let mut tmp = self; + // swap the head and sequence axes + tmp.swap_axes(1, 2); + // reshape the qkv matrix into a 2d array + let res = tmp.to_shape(((batch, seq, heads * query), Order::ColumnMajor))?; + Ok(res.to_owned()) + } +} diff --git a/models/transformers/src/ops/mod.rs b/models/transformers/src/ops/mod.rs new file mode 100644 index 00000000..ee4e98fa --- /dev/null +++ b/models/transformers/src/ops/mod.rs @@ -0,0 +1,130 @@ +/* + Appellation: ops + Contrib: FL03 +*/ +pub use self::{merge::*, split::*, utils::*}; + +pub(crate) mod merge; +pub(crate) mod split; + +pub(crate) mod utils { + use concision::NdResult; + use nd::prelude::*; + use nd::{Data, Order, RemoveAxis}; + + #[doc(hidden)] + pub fn merge( + z: &mut ArrayBase, + swap: usize, + with: usize, + ) -> NdResult> + where + A: Clone, + 
S: Data, + D: RemoveAxis, + E: Dimension, + { + let cur = z.raw_dim().as_array_view().to_owned(); + let indicies = (0..cur.ndim()).filter(|&i| i != swap).collect::>(); + let new_axis = cur[swap] * cur[with]; + let mut dim = cur.select(Axis(0), &indicies); + dim[with - 1] = new_axis; + + // swap the head and sequence axes + z.swap_axes(swap, with); + // reshape the qkv matrix into a smaller dimension + // z.to_shape((dim, Order::ColumnMajor)) + unimplemented!() + } + #[doc(hidden)] + pub fn merge_simple( + z: &mut ArrayBase, + dim: E, + swap: usize, + with: usize, + ) -> NdResult> + where + A: Clone, + S: Data, + D: RemoveAxis, + E: Dimension, + { + // swap the head and sequence axes + z.swap_axes(swap, with); + // reshape the qkv matrix into a smaller dimension + z.to_shape((dim, Order::ColumnMajor)) + } + + pub fn merge_heads(heads: &Array3) -> NdResult> + where + A: Clone, + { + let (n, seq, query) = heads.dim(); + let mut tmp = heads.clone(); + // swap the head and sequence axes + tmp.swap_axes(0, 1); + // reshape the qkv matrix into a 2d array + tmp.into_shape((seq, n * query)) + } + + pub fn split_heads(param: &Array2, num_heads: usize) -> NdResult> + where + T: Clone, + { + let dim = param.shape().last().unwrap() / num_heads; + // reshape the qkv matrix into a 3d array + let mut res = param + .clone() + .into_shape((param.shape()[0], num_heads, dim))?; + // swap the sequence and head axes + res.swap_axes(0, 1); + Ok(res) + } + + pub fn merge_batch(heads: &Array4) -> NdResult> + where + T: Clone, + { + let (batch, n, seq, query) = heads.dim(); + let mut tmp = heads.clone(); + // swap the head and sequence axes + tmp.swap_axes(1, 2); + // reshape the qkv matrix into a 2d array + tmp.into_shape((batch, seq, n * query)) + } + + pub fn split_batch(param: &Array3, num_heads: usize) -> NdResult> + where + T: Clone, + { + let dim = param.shape().last().unwrap() / num_heads; + // reshape the qkv matrix into a 3d array + let mut res = + param + .clone() + 
.into_shape((param.shape()[0], param.shape()[1], num_heads, dim))?; + // swap the sequence and head axes + res.swap_axes(1, 2); + Ok(res) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::Array; + + #[test] + fn reshape_ops() { + let dim_input: [usize; 3] = [2, 4, 6]; // (batch, seq, model) + let dim_split = [2, 2, 4, 3]; // (batch, heads, seq, model) + let data = Array::linspace(1., 48., 48).into_shape(dim_input).unwrap(); + + let a = split_batch(&data, 2).unwrap(); + assert_eq!(a.shape(), &dim_split); + assert_eq!(&a, &data.split(2).unwrap()); + let b = merge_batch(&a).unwrap(); + assert_eq!(b.shape(), &dim_input); + assert_eq!(&b, &data); + } +} diff --git a/models/transformers/src/ops/split.rs b/models/transformers/src/ops/split.rs new file mode 100644 index 00000000..13c367e4 --- /dev/null +++ b/models/transformers/src/ops/split.rs @@ -0,0 +1,50 @@ +/* + Appellation: split + Contrib: FL03 +*/ +use ndarray::prelude::{Array2, Array3, Array4}; +use ndarray::ShapeError; + +// pub fn split(param: &Array, heads: usize) -> Result, ShapeError> { +// let mut dim = param.dim() +// let query = param.shape().last().unwrap() / heads; +// // reshape the qkv matrix into a 3d array +// let mut res = param.clone().into_shape((param.shape()[0], heads, query))?; +// // swap the sequence and head axes +// res.swap_axes(0, 1); +// Ok(res) +// } + +pub trait Split { + type Error; + + fn split(&self, heads: usize) -> Result; +} + +impl Split> for Array2 { + type Error = ShapeError; + + fn split(&self, heads: usize) -> Result, Self::Error> { + let (seq, model) = self.dim(); + let query = model / heads; + // reshape the qkv matrix into a 3d array + let mut res = self.clone().into_shape((seq, heads, query))?; + // swap the sequence and head axes + res.swap_axes(0, 1); + Ok(res) + } +} + +impl Split> for Array3 { + type Error = ShapeError; + + fn split(&self, heads: usize) -> Result, Self::Error> { + let (batch, seq, model) = self.dim(); + let query = model / heads; + 
// reshape the qkv matrix into a 3d array + let mut res = self.clone().into_shape((batch, seq, heads, query))?; + // swap the sequence and head axes + res.swap_axes(1, 2); + Ok(res) + } +} diff --git a/models/transformers/src/params/mod.rs b/models/transformers/src/params/mod.rs index 40cf9cc1..367f8b2a 100644 --- a/models/transformers/src/params/mod.rs +++ b/models/transformers/src/params/mod.rs @@ -2,7 +2,7 @@ Appellation: params Contrib: FL03 */ -pub use self::{item::*, store::ParamsBase}; +pub use self::{item::*, store::QkvBase}; pub(crate) mod item; pub(crate) mod store; @@ -20,7 +20,7 @@ macro_rules! params_ty { } params_ty!( - ParamsBase: [ + QkvBase: [ Params, ArcParams, ParamsView<&'a ViewRepr>, @@ -30,6 +30,6 @@ params_ty!( #[allow(unused_imports)] pub(crate) mod prelude { pub use super::item::{Entry, QKV}; - pub use super::store::ParamsBase; + pub use super::store::QkvBase; pub use super::{ArcParams, Params}; } diff --git a/models/transformers/src/params/store.rs b/models/transformers/src/params/store.rs index 2b107531..90c13693 100644 --- a/models/transformers/src/params/store.rs +++ b/models/transformers/src/params/store.rs @@ -6,7 +6,7 @@ use concision::{dimensional, getters}; use nd::*; use num::traits::{One, Zero}; -pub struct ParamsBase, D = Ix2> +pub struct QkvBase, D = Ix2> where D: Dimension, S: RawData, @@ -16,7 +16,7 @@ where pub(crate) v: ArrayBase, } -impl ParamsBase +impl QkvBase where D: Dimension, S: RawData, From 0fa7b9e230ce2a06a3af3191e3bd4b6dd7ba27df Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sat, 18 May 2024 15:24:55 -0500 Subject: [PATCH 21/23] update Signed-off-by: Joe McCain III --- models/linear/tests/linear.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/models/linear/tests/linear.rs b/models/linear/tests/linear.rs index 6155e332..a941510d 100644 --- a/models/linear/tests/linear.rs +++ b/models/linear/tests/linear.rs @@ -38,12 +38,11 @@ fn test_model_toggle() { let model = 
Linear::::from_features(inputs, outputs); assert!(!model.is_biased()); - + let model = Linear::::from_features(inputs, outputs).into_unbiased(); assert!(!model.is_biased()); } - #[test] #[cfg(feature = "rand")] fn test_linear() { @@ -56,4 +55,3 @@ fn test_linear() { assert_eq!(y.shape(), &[samples, outputs]); } - From ed5abd0192d1d0ab9a9d49983a25e7746788fa80 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sun, 19 May 2024 09:43:17 -0500 Subject: [PATCH 22/23] update Signed-off-by: Joe McCain III --- concision/examples/linear.rs | 8 ++- core/src/func/activate/linear.rs | 21 +++++++ core/src/func/activate/mod.rs | 32 +++++------ core/src/func/activate/nl.rs | 2 +- core/src/func/loss/mod.rs | 4 +- core/src/func/loss/reg/avg.rs | 23 +++++++- core/src/func/mod.rs | 1 + core/src/init/gen/lecun.rs | 49 +++++++++++++++++ core/src/init/initialize.rs | 43 ++++++++++++++- core/src/init/mod.rs | 20 +++++-- core/src/init/utils.rs | 94 +++++--------------------------- core/src/lib.rs | 2 - core/src/macros.rs | 3 +- core/src/macros/activate.rs | 36 ++++++++++++ core/src/traits/predict.rs | 10 ---- 15 files changed, 222 insertions(+), 126 deletions(-) create mode 100644 core/src/func/activate/linear.rs create mode 100644 core/src/macros/activate.rs diff --git a/concision/examples/linear.rs b/concision/examples/linear.rs index 5a063c8c..faab3951 100644 --- a/concision/examples/linear.rs +++ b/concision/examples/linear.rs @@ -4,7 +4,8 @@ */ extern crate concision as cnc; -use cnc::prelude::{linarr, Linear, Result, Sigmoid}; +use cnc::linear::Features; +use cnc::prelude::{linarr, InitializeExt, Linear, Result, Sigmoid}; use ndarray::Ix2; fn tracing() { @@ -25,14 +26,15 @@ fn main() -> Result<()> { tracing::info!("Starting linear model example"); let (samples, d_in, d_out) = (20, 5, 3); + let features = Features::new(d_out, d_in); let data = linarr::((samples, d_in)).unwrap(); - let model = Linear::::from_features(d_in, d_out).uniform(); + let model = 
Linear::::lecun_normal(features, d_in).unwrap(); assert!(model.is_biased()); let y = model.activate(&data, Sigmoid::sigmoid).unwrap(); assert_eq!(y.dim(), (samples, d_out)); - println!("Predictions:\n{:?}", &y); + println!("Predictions:\n{:#?}", &y); Ok(()) } diff --git a/core/src/func/activate/linear.rs b/core/src/func/activate/linear.rs new file mode 100644 index 00000000..d4c9dc7e --- /dev/null +++ b/core/src/func/activate/linear.rs @@ -0,0 +1,21 @@ +/* + Appellation: linear + Contrib: FL03 +*/ + +pub fn linear(x: T) -> T { + x +} + +unary!(LinearActivation::linear(self)); + +impl<'a, T> LinearActivation for &'a T +where + T: Clone, +{ + type Output = T; + + fn linear(self) -> Self::Output { + self.clone() + } +} diff --git a/core/src/func/activate/mod.rs b/core/src/func/activate/mod.rs index d2bfba90..cd1c2f0b 100644 --- a/core/src/func/activate/mod.rs +++ b/core/src/func/activate/mod.rs @@ -2,32 +2,24 @@ Appellation: activate Contrib: FL03 */ -pub use self::{binary::*, nl::*}; +pub use self::{binary::*, linear::*, nl::*}; pub mod binary; +pub mod linear; pub mod nl; -pub fn linear(x: T) -> T { - x -} - -unary!(LinearActivation::linear(self)); - -impl<'a, T> LinearActivation for &'a T -where - T: Clone, -{ - type Output = T; - - fn linear(self) -> Self::Output { - self.clone() - } -} - pub(crate) mod prelude { pub use super::binary::*; + pub use super::linear::*; pub use super::nl::*; - pub use super::{linear, LinearActivation}; + pub use super::{Activate, Evaluate}; +} + +#[doc(hidden)] +pub trait Activate { + type Output; + + fn activate(&self, args: &T) -> Self::Output; } #[doc(hidden)] @@ -36,3 +28,5 @@ pub trait Evaluate { fn eval(&self, args: T) -> Self::Output; } + +activator!(LinearActor::(T::clone) where T: Clone); diff --git a/core/src/func/activate/nl.rs b/core/src/func/activate/nl.rs index fa1d364b..694145c7 100644 --- a/core/src/func/activate/nl.rs +++ b/core/src/func/activate/nl.rs @@ -137,7 +137,7 @@ nonlinear!( f32, f64, Complex, - Complex < f64 
> + Complex ]>, ); diff --git a/core/src/func/loss/mod.rs b/core/src/func/loss/mod.rs index b88f91f5..f22b23b6 100644 --- a/core/src/func/loss/mod.rs +++ b/core/src/func/loss/mod.rs @@ -11,8 +11,8 @@ pub(crate) mod prelude { pub use super::Loss; } -pub trait Loss { +pub trait Loss { type Output; - fn loss(&self, cmp: &T) -> Self::Output; + fn loss(&self, a: &A, cmp: &B) -> Self::Output; } diff --git a/core/src/func/loss/reg/avg.rs b/core/src/func/loss/reg/avg.rs index 163ef665..23014ce9 100644 --- a/core/src/func/loss/reg/avg.rs +++ b/core/src/func/loss/reg/avg.rs @@ -37,6 +37,25 @@ pub trait MeanSquaredError { fn mse(&self, target: &Rhs) -> Self::Output; } +losses! { + impl MSE::, ArrayBase, Output = Option>(mse) + where + A: FromPrimitive + Num + Pow + ScalarOperand, + D: Dimension, + S: Data, +} + +losses! { + impl MAE::, ArrayBase, Output = Option>(mae) + where + A: FromPrimitive + Num + ScalarOperand + Signed, + D: Dimension, + S: Data, +} + +/* + ************* Implementations ************* +*/ impl MeanAbsoluteError> for ArrayBase where A: FromPrimitive + Num + ScalarOperand + Signed, @@ -46,7 +65,7 @@ where type Output = Option; fn mae(&self, target: &ArrayBase) -> Self::Output { - (self - target).abs().mean() + mae(self, target) } } @@ -59,6 +78,6 @@ where type Output = Option; fn mse(&self, target: &ArrayBase) -> Self::Output { - (self - target).sqrd().mean() + mse(self, target) } } diff --git a/core/src/func/mod.rs b/core/src/func/mod.rs index 545415bd..bb99ccba 100644 --- a/core/src/func/mod.rs +++ b/core/src/func/mod.rs @@ -5,6 +5,7 @@ //! 
Functional pub use self::prelude::*; +#[macro_use] pub mod activate; pub mod dropout; pub mod loss; diff --git a/core/src/init/gen/lecun.rs b/core/src/init/gen/lecun.rs index 33ee0cde..cc419280 100644 --- a/core/src/init/gen/lecun.rs +++ b/core/src/init/gen/lecun.rs @@ -2,3 +2,52 @@ Appellation: lecun Contrib: FL03 */ +use num::Float; +use rand_distr::{Distribution, Normal, NormalError, StandardNormal}; + +/// Create a [Normal](rand_distr::Normal) distribution with a standard deviation of sqrt(1/n) +/// where n is the number of inputs. +pub fn lecun_normal(n: usize) -> Result, NormalError> +where + F: Float, + StandardNormal: Distribution, +{ + let std_dev = F::from(n).unwrap().recip().sqrt(); + Normal::new(F::zero(), std_dev) +} + +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct LecunNormal { + n: usize, +} + +impl LecunNormal { + pub fn new(n: usize) -> Self { + Self { n } + } + + pub fn distr(&self) -> Result, NormalError> + where + F: Float, + StandardNormal: Distribution, + { + lecun_normal(self.n) + } + + pub fn std(&self) -> F + where + F: Float, + { + F::from(self.n).unwrap().recip().sqrt() + } +} + +impl Distribution for LecunNormal +where + F: Float, + StandardNormal: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + self.distr().unwrap().sample(rng) + } +} diff --git a/core/src/init/initialize.rs b/core/src/init/initialize.rs index db72acf0..d985f67c 100644 --- a/core/src/init/initialize.rs +++ b/core/src/init/initialize.rs @@ -5,10 +5,11 @@ use core::ops::Neg; use nd::{ArrayBase, DataOwned, Dimension, RawData, ShapeBuilder}; use ndrand::RandomExt; +use num::complex::ComplexDistribution; use num::traits::Float; use rand::{rngs, Rng, SeedableRng}; use rand_distr::uniform::{SampleUniform, Uniform}; -use rand_distr::{Bernoulli, BernoulliError, Distribution, StandardNormal}; +use rand_distr::{Bernoulli, BernoulliError, Distribution, Normal, StandardNormal}; /// This trait provides the base methods required for 
initializing an [ndarray](ndarray::ArrayBase) with random values. /// [Initialize] is similar to [RandomExt](ndarray_rand::RandomExt), however, it focuses on flexibility while implementing additional @@ -52,15 +53,50 @@ where D: Dimension, S: RawData, { - fn bernoulli(shape: Sh, p: Option) -> Result + fn bernoulli(shape: Sh, p: f64) -> Result where S: DataOwned, Sh: ShapeBuilder, Bernoulli: Distribution, { - let dist = Bernoulli::new(p.unwrap_or(0.5))?; + let dist = Bernoulli::new(p)?; Ok(Self::rand(shape, dist)) } + /// Initialize the object according to the Lecun Initialization scheme. + /// LecunNormal distributions are truncated [Normal](rand_distr::Normal) + /// distributions centered at 0 with a standard deviation equal to the + /// square root of the reciprocal of the number of inputs. + fn lecun_normal(shape: Sh, n: usize) -> Result + where + A: Float, + S: DataOwned, + Sh: ShapeBuilder, + StandardNormal: Distribution, + { + let std = A::from(n).unwrap().recip().sqrt(); + Self::normal(shape, A::zero(), std) + } + /// Given a shape, mean, and standard deviation generate a new object using the [Normal](rand_distr::Normal) distribution + fn normal(shape: Sh, mean: A, std: A) -> Result + where + A: Float, + S: DataOwned, + Sh: ShapeBuilder, + StandardNormal: Distribution, + { + let distr = Normal::new(mean, std)?; + Ok(Self::rand(shape, distr)) + } + + fn randc(shape: Sh, re: A, im: A) -> Self + where + S: DataOwned, + Sh: ShapeBuilder, + ComplexDistribution: Distribution, + { + let distr = ComplexDistribution::new(re, im); + Self::rand(shape, distr) + } /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution fn stdnorm(shape: Sh) -> Self where @@ -126,6 +162,7 @@ where ArrayBase: RandomExt, { type Data = S; + fn rand(shape: Sh, distr: Ds) -> ArrayBase where S: DataOwned, diff --git a/core/src/init/mod.rs b/core/src/init/mod.rs index 6cb1c6db..22ee1bac 100644 --- a/core/src/init/mod.rs +++ b/core/src/init/mod.rs @@ -1,7 
+1,14 @@ /* - Appellation: rand - Contrib: FL03 + Appellation: init + Contrib: FL03 */ +//! # Initialization +//! +//! This module implements several initialization primitives for generating tensors using +//! various distributions and strategies. The module is designed to be used in conjunction with +//! the `rand` and `rand_distr` libraries. While `ndarray_rand` provides a `RandomExt` trait, +//! we provide an alternative [Initialize] trait which is designed to be more flexible and +//! better suited for machine-learning workloads. #![cfg(feature = "rand")] pub use self::prelude::*; @@ -9,9 +16,14 @@ pub use self::prelude::*; pub(crate) mod initialize; pub(crate) mod utils; -#[doc(hidden)] pub mod gen { + pub use self::prelude::*; + pub mod lecun; + + pub(crate) mod prelude { + pub use super::lecun::*; + } } #[doc(no_inline)] @@ -22,7 +34,7 @@ pub use rand; pub use rand_distr; pub(crate) mod prelude { - #[doc(hidden)] + pub use super::gen::prelude::*; pub use super::initialize::{Initialize, InitializeExt}; pub use super::utils::*; } diff --git a/core/src/init/utils.rs b/core/src/init/utils.rs index c4ab0470..3994589c 100644 --- a/core/src/init/utils.rs +++ b/core/src/init/utils.rs @@ -3,64 +3,24 @@ Contrib: FL03 */ use ndarray::*; -use ndrand::rand::rngs::StdRng; -use ndrand::rand::{rngs, Rng, SeedableRng}; -use ndrand::rand_distr::{Distribution, StandardNormal}; use ndrand::RandomExt; use num::complex::{Complex, ComplexDistribution}; -use num::traits::real::Real; use num::Num; use rand::distributions::uniform::{SampleUniform, Uniform}; -

pub fn lecun_normal(shape: impl IntoDimension) -> Array -where - D: Dimension, - T: Real + ScalarOperand, - 
StandardNormal: Distribution, -{ - let dim = shape.into_dimension(); - let n = dim.size(); - let scale = T::from(n).unwrap().recip().sqrt(); - Array::random_using(dim, StandardNormal, &mut rngs::StdRng::seed_from_u64(seed)) * scale -} +use rand::rngs::StdRng; +use rand::{rngs, SeedableRng}; +use rand_distr::{Distribution, StandardNormal}; /// Generate a random array of complex numbers with real and imaginary parts in the range [0, 1) -pub fn randc(shape: impl IntoDimension) -> Array, D> +pub fn randc(shape: impl IntoDimension) -> ArrayBase where + A: Clone + Num, D: Dimension, - T: Clone + Num, - ComplexDistribution: Distribution>, + S: DataOwned>, + ComplexDistribution: Distribution>, { - let distr = ComplexDistribution::::new(T::one(), T::one()); - Array::random(shape, distr) -} -/// -pub fn randcomplex(shape: impl IntoDimension) -> Array, D> -where - D: Dimension, - T: Copy + Num, - StandardNormal: Distribution, -{ - let dim = shape.into_dimension(); - let re = Array::random(dim.clone(), StandardNormal); - let im = Array::random(dim.clone(), StandardNormal); - let mut res = Array::zeros(dim); - ndarray::azip!((re in &re, im in &im, res in &mut res) { - *res = Complex::new(*re, *im); - }); - res + let distr = ComplexDistribution::::new(A::one(), A::one()); + ArrayBase::random(shape, distr) } /// Creates a random array from a uniform distribution using a given key pub fn seeded_uniform( @@ -79,47 +39,23 @@ where &mut rngs::StdRng::seed_from_u64(key), ) } -/// -pub fn seeded_stdnorm(shape: impl IntoDimension, key: u64) -> Array -where - D: Dimension, - StandardNormal: Distribution, -{ - Array::random_using(shape, StandardNormal, &mut rngs::StdRng::seed_from_u64(key)) -} -/// -pub fn randc_normal(key: u64, shape: impl IntoDimension) -> Array, D> -where - D: Dimension, - T: Copy + Num, - StandardNormal: Distribution, -{ - let dim = shape.into_dimension(); - let re = seeded_stdnorm(dim.clone(), key); - let im = seeded_stdnorm(dim.clone(), key); - let mut res = 
Array::zeros(dim); - azip!((re in &re, im in &im, res in &mut res) { - *res = Complex::new(*re, *im); - }); - res -} /// Given a shape, generate a random array using the StandardNormal distribution -pub fn stdnorm(shape: impl IntoDimension) -> Array +pub fn stdnorm(shape: Sh) -> ArrayBase where D: Dimension, - StandardNormal: Distribution, + S: DataOwned, + Sh: ShapeBuilder, + StandardNormal: Distribution, { - Array::random(shape, StandardNormal) + ArrayBase::random(shape, StandardNormal) } -pub fn stdnorm_from_seed(shape: Sh, seed: u64) -> ArrayBase +pub fn stdnorm_from_seed(shape: Sh, seed: u64) -> ArrayBase where D: Dimension, - R: Rng + ?Sized, S: DataOwned, Sh: ShapeBuilder, StandardNormal: Distribution, - ArrayBase: RandomExt, { ArrayBase::random_using(shape, StandardNormal, &mut StdRng::seed_from_u64(seed)) } diff --git a/core/src/lib.rs b/core/src/lib.rs index 0c6d4ec8..5906aa6f 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -22,10 +22,8 @@ pub use self::init::{Initialize, InitializeExt}; pub(crate) mod macros; pub(crate) mod primitives; -#[macro_use] pub mod error; pub mod func; -#[cfg(feature = "rand")] pub mod init; pub mod math; pub mod nn; diff --git a/core/src/macros.rs b/core/src/macros.rs index af6a1fec..aaf6ceb6 100644 --- a/core/src/macros.rs +++ b/core/src/macros.rs @@ -2,7 +2,8 @@ Appellation: macros Contrib: FL03 */ - +#[macro_use] +mod activate; #[macro_use] mod builder; #[macro_use] diff --git a/core/src/macros/activate.rs b/core/src/macros/activate.rs new file mode 100644 index 00000000..44e6eeb3 --- /dev/null +++ b/core/src/macros/activate.rs @@ -0,0 +1,36 @@ +/* + Appellation: activate + Contrib: FL03 +*/ + +macro_rules! 
activator { + ($name:ident::<$out:ty>($rho:expr) $($rest:tt)*) => { + #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] + #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] + pub struct $name; + + impl $crate::func::activate::Activate for $name $($rest)* { + type Output = $out; + + fn activate(&self, args: &T) -> Self::Output { + $rho(args) + } + } + }; +} + +macro_rules! losses { + (impl<$($T:ident),* $(,)?> $name:ident::<$lhs:ty, $rhs:ty, Output = $out:ty>($loss:expr) $($rest:tt)*) => { + #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] + #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] + pub struct $name; + + impl<$($T),*> $crate::func::Loss<$lhs, $rhs> for $name $($rest)* { + type Output = $out; + + fn loss(&self, a: &$lhs, b: &$rhs) -> Self::Output { + $loss(a, b) + } + } + }; +} diff --git a/core/src/traits/predict.rs b/core/src/traits/predict.rs index ab842988..632d05b1 100644 --- a/core/src/traits/predict.rs +++ b/core/src/traits/predict.rs @@ -4,16 +4,6 @@ */ use crate::error::PredictError; -#[doc(hidden)] -pub trait Activate: Forward -where - F: for<'a> Fn(&'a Self::Output) -> Self::Output, -{ - fn activate(&self, args: &T, f: F) -> Self::Output { - f(&self.forward(args)) - } -} - /// [Forward] describes an object capable of forward propagation. 
pub trait Forward { type Output; From eb5a57e75ffa1930d8176bb0acdee5f5092d04f8 Mon Sep 17 00:00:00 2001 From: Joe McCain III Date: Sun, 19 May 2024 12:14:22 -0500 Subject: [PATCH 23/23] update Signed-off-by: Joe McCain III --- concision/examples/linear.rs | 12 +- core/src/func/loss/entropy.rs | 8 ++ core/src/func/loss/mod.rs | 5 + core/src/func/loss/reg/avg.rs | 26 +--- core/src/func/loss/utils.rs | 29 ++++ core/src/init/gen/lecun.rs | 32 ++--- core/src/init/initialize.rs | 8 +- models/linear/src/macros.rs | 114 +-------------- models/linear/src/macros/model.rs | 22 +++ models/linear/src/macros/params.rs | 92 ++++++++++++ models/linear/src/model/layer.rs | 9 +- models/linear/src/norm/layer/model.rs | 47 ++++--- models/linear/src/params/store.rs | 18 +-- models/transformers/Cargo.toml | 16 +++ models/transformers/src/attention/head.rs | 1 + models/transformers/src/attention/mod.rs | 9 +- models/transformers/src/codec/decoder.rs | 21 ++- .../transformers/src/codec/decoder/config.rs | 14 ++ models/transformers/src/codec/encoder.rs | 29 +++- .../transformers/src/codec/encoder/config.rs | 14 ++ models/transformers/src/lib.rs | 4 + models/transformers/src/ops/merge.rs | 46 +++--- models/transformers/src/ops/mod.rs | 131 ++++++++---------- models/transformers/src/ops/split.rs | 18 +-- models/transformers/src/primitives.rs | 10 ++ models/transformers/tests/ops.rs | 52 +++++++ 26 files changed, 475 insertions(+), 312 deletions(-) create mode 100644 core/src/func/loss/utils.rs create mode 100644 models/linear/src/macros/model.rs create mode 100644 models/linear/src/macros/params.rs create mode 100644 models/transformers/src/codec/decoder/config.rs create mode 100644 models/transformers/src/codec/encoder/config.rs create mode 100644 models/transformers/src/primitives.rs create mode 100644 models/transformers/tests/ops.rs diff --git a/concision/examples/linear.rs b/concision/examples/linear.rs index faab3951..64d77b4c 100644 --- a/concision/examples/linear.rs +++ 
b/concision/examples/linear.rs @@ -24,16 +24,16 @@ fn tracing() { fn main() -> Result<()> { tracing(); tracing::info!("Starting linear model example"); + let samples = 20; + let (dm, dn) = (5, 3); + let features = Features::new(dn, dm); + let data = linarr::((samples, dm)).unwrap(); - let (samples, d_in, d_out) = (20, 5, 3); - let features = Features::new(d_out, d_in); - let data = linarr::((samples, d_in)).unwrap(); - - let model = Linear::::lecun_normal(features, d_in).unwrap(); + let model = Linear::::lecun_normal(features, dm); assert!(model.is_biased()); let y = model.activate(&data, Sigmoid::sigmoid).unwrap(); - assert_eq!(y.dim(), (samples, d_out)); + assert_eq!(y.dim(), (samples, dn)); println!("Predictions:\n{:#?}", &y); Ok(()) diff --git a/core/src/func/loss/entropy.rs b/core/src/func/loss/entropy.rs index 1982a20f..5e966b72 100644 --- a/core/src/func/loss/entropy.rs +++ b/core/src/func/loss/entropy.rs @@ -2,3 +2,11 @@ Appellation: entropy Contrib: FL03 */ + +pub trait Entropy { + type Output; + + fn cross_entropy(&self, target: &T) -> Self::Output; +} + +pub struct CrossEntropy; diff --git a/core/src/func/loss/mod.rs b/core/src/func/loss/mod.rs index f22b23b6..71694cbb 100644 --- a/core/src/func/loss/mod.rs +++ b/core/src/func/loss/mod.rs @@ -2,12 +2,17 @@ Appellation: loss Contrib: FL03 */ +pub use self::reg::prelude::*; +pub use self::{entropy::*, utils::*}; + +pub(crate) mod utils; pub mod entropy; pub mod reg; pub(crate) mod prelude { pub use super::reg::prelude::*; + pub use super::utils::*; pub use super::Loss; } diff --git a/core/src/func/loss/reg/avg.rs b/core/src/func/loss/reg/avg.rs index 23014ce9..25d80beb 100644 --- a/core/src/func/loss/reg/avg.rs +++ b/core/src/func/loss/reg/avg.rs @@ -7,24 +7,6 @@ use nd::prelude::*; use nd::{Data, ScalarOperand}; use num::traits::{FromPrimitive, Num, Pow, Signed}; -pub fn mae(pred: &ArrayBase, target: &ArrayBase) -> Option -where - A: FromPrimitive + Num + ScalarOperand + Signed, - D: Dimension, - S: Data, 
-{ - (pred - target).abs().mean() -} - -pub fn mse(pred: &ArrayBase, target: &ArrayBase) -> Option -where - A: FromPrimitive + Num + Pow + ScalarOperand, - D: Dimension, - S: Data, -{ - (pred - target).sqrd().mean() -} - pub trait MeanAbsoluteError { type Output; @@ -38,7 +20,7 @@ pub trait MeanSquaredError { } losses! { - impl MSE::, ArrayBase, Output = Option>(mse) + impl MSE::, ArrayBase, Output = Option>(MeanSquaredError::mse) where A: FromPrimitive + Num + Pow + ScalarOperand, D: Dimension, @@ -46,7 +28,7 @@ losses! { } losses! { - impl MAE::, ArrayBase, Output = Option>(mae) + impl MAE::, ArrayBase, Output = Option>(MeanAbsoluteError::mae) where A: FromPrimitive + Num + ScalarOperand + Signed, D: Dimension, @@ -65,7 +47,7 @@ where type Output = Option; fn mae(&self, target: &ArrayBase) -> Self::Output { - mae(self, target) + (target - self).abs().mean() } } @@ -78,6 +60,6 @@ where type Output = Option; fn mse(&self, target: &ArrayBase) -> Self::Output { - mse(self, target) + (target - self).sqrd().mean() } } diff --git a/core/src/func/loss/utils.rs b/core/src/func/loss/utils.rs new file mode 100644 index 00000000..9f61779b --- /dev/null +++ b/core/src/func/loss/utils.rs @@ -0,0 +1,29 @@ +/* + Appellation: utils + Contrib: FL03 +*/ +use crate::math::{Abs, Squared}; +use nd::prelude::*; +use nd::{Data, ScalarOperand}; +use num::traits::{FromPrimitive, Num, Pow, Signed}; + +/// A functional implementation of the mean absolute error loss function which compares two similar +/// [arrays](ndarray::ArrayBase) +pub fn mae(pred: &ArrayBase, target: &ArrayBase) -> Option +where + A: FromPrimitive + Num + ScalarOperand + Signed, + D: Dimension, + S: Data, +{ + (pred - target).abs().mean() +} +/// A functional implementation of the mean squared error loss function that compares two similar +/// [arrays](ndarray::ArrayBase) +pub fn mse(pred: &ArrayBase, target: &ArrayBase) -> Option +where + A: FromPrimitive + Num + Pow + ScalarOperand, + D: Dimension, + S: Data, +{ + 
(pred - target).sqrd().mean() +} diff --git a/core/src/init/gen/lecun.rs b/core/src/init/gen/lecun.rs index cc419280..b8cae16c 100644 --- a/core/src/init/gen/lecun.rs +++ b/core/src/init/gen/lecun.rs @@ -3,19 +3,12 @@ Contrib: FL03 */ use num::Float; +use rand::Rng; use rand_distr::{Distribution, Normal, NormalError, StandardNormal}; -/// Create a [Normal](rand_distr::Normal) distribution with a standard deviation of sqrt(1/n) -/// where n is the number of inputs. -pub fn lecun_normal(n: usize) -> Result, NormalError> -where - F: Float, - StandardNormal: Distribution, -{ - let std_dev = F::from(n).unwrap().recip().sqrt(); - Normal::new(F::zero(), std_dev) -} - +/// [LecunNormal] is a truncated [normal](rand_distr::Normal) distribution centered at 0 +/// with a standard deviation that is calculated as `σ = sqrt(1/n_in)` +/// where `n_in` is the number of input units. #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct LecunNormal { n: usize, @@ -25,16 +18,20 @@ impl LecunNormal { pub fn new(n: usize) -> Self { Self { n } } - + /// Create a [normal](rand_distr::Normal) [distribution](Distribution) centered at 0; + /// See [Self::std_dev] for the standard deviation calculations. pub fn distr(&self) -> Result, NormalError> where F: Float, StandardNormal: Distribution, { - lecun_normal(self.n) + Normal::new(F::zero(), self.std_dev()) } - - pub fn std(&self) -> F + /// Calculate the standard deviation (`σ`) of the distribution. 
+ /// This is done by computing the root of the reciprocal of the number of inputs + /// + /// Symbolically: `σ = sqrt(1/n)` + pub fn std_dev(&self) -> F where F: Float, { @@ -47,7 +44,10 @@ where F: Float, StandardNormal: Distribution, { - fn sample(&self, rng: &mut R) -> F { + fn sample(&self, rng: &mut R) -> F + where + R: Rng + ?Sized, + { self.distr().unwrap().sample(rng) } } diff --git a/core/src/init/initialize.rs b/core/src/init/initialize.rs index d985f67c..91b41b13 100644 --- a/core/src/init/initialize.rs +++ b/core/src/init/initialize.rs @@ -11,6 +11,8 @@ use rand::{rngs, Rng, SeedableRng}; use rand_distr::uniform::{SampleUniform, Uniform}; use rand_distr::{Bernoulli, BernoulliError, Distribution, Normal, StandardNormal}; +use super::LecunNormal; + /// This trait provides the base methods required for initializing an [ndarray](ndarray::ArrayBase) with random values. /// [Initialize] is similar to [RandomExt](ndarray_rand::RandomExt), however, it focuses on flexibility while implementing additional /// features geared towards machine-learning models; such as lecun_normal initialization. @@ -66,15 +68,15 @@ where /// LecunNormal distributions are truncated [Normal](rand_distr::Normal) /// distributions centered at 0 with a standard deviation equal to the /// square root of the reciprocal of the number of inputs. 
- fn lecun_normal(shape: Sh, n: usize) -> Result + fn lecun_normal(shape: Sh, n: usize) -> Self where A: Float, S: DataOwned, Sh: ShapeBuilder, StandardNormal: Distribution, { - let std = A::from(n).unwrap().recip().sqrt(); - Self::normal(shape, A::zero(), std) + let distr = LecunNormal::new(n); + Self::rand(shape, distr) } /// Given a shape, mean, and standard deviation generate a new object using the [Normal](rand_distr::Normal) distribution fn normal(shape: Sh, mean: A, std: A) -> Result diff --git a/models/linear/src/macros.rs b/models/linear/src/macros.rs index 99ec867a..b5672fb9 100644 --- a/models/linear/src/macros.rs +++ b/models/linear/src/macros.rs @@ -3,113 +3,7 @@ Contrib: FL03 */ -macro_rules! impl_params_builder { - ($method:ident$(.$call:ident)? where $($rest:tt)*) => { - impl_params_builder!(@impl $method$(.$call)? where $($rest)*); - }; - (@impl $method:ident where $($rest:tt)*) => { - impl_params_builder!(@impl $method.$method where $($rest)*); - }; - (@impl $method:ident.$call:ident where $($rest:tt)*) => { - pub fn $method(shape: Sh) -> Self - where - K: $crate::params::mode::ParamMode, - Sh: ndarray::ShapeBuilder, - $($rest)* - { - let dim = shape.into_shape().raw_dim().clone(); - ParamsBase { - bias: build_bias(K::BIASED, dim.clone(), |dim| ndarray::ArrayBase::$call(dim)), - weight: ndarray::ArrayBase::$call(dim), - _mode: ::core::marker::PhantomData::, - } - } - }; -} - -macro_rules! impl_model_builder { - ($method:ident$(.$call:ident)? where $($rest:tt)*) => { - impl_model_builder!(@impl $method$(.$call)? 
where $($rest)*); - }; - (@impl $method:ident where $($rest:tt)*) => { - impl_model_builder!(@impl $method.$method where $($rest)*); - }; - (@impl $method:ident.$call:ident where $($rest:tt)*) => { - pub fn $method(shape: Sh) -> Self - where - K: $crate::params::mode::ParamMode, - Sh: ndarray::ShapeBuilder, - $($rest)* - { - let dim = shape.into_shape().raw_dim().clone(); - $crate::model::Linear { - config: $crate::model::Config::::new().with_shape(dim.clone()), - params: $crate::params::ParamsBase::$call(dim), - } - } - }; -} - -macro_rules! ndview { - ($method:ident::$($rest:tt)*) => { - ndview!(@impl $method.$method::$($rest)*); - }; - ($method:ident.$call:ident::$($rest:tt)*) => { - ndview!(@impl $method.$call::$($rest)*); - }; - (@impl $method:ident.$call:ident::<$view:ident>(self) where $($rest:tt)*) => { - pub fn $method(self) -> $crate::params::ParamsBase<$view, D, K> - where - $($rest)* - { - ndview!(@apply $call(self)) - } - }; - (@impl $method:ident.$call:ident::<$view:ident>(mut self) where $($rest:tt)*) => { - pub fn $method(mut self) -> $crate::params::ParamsBase<$view, D, K> - where - $($rest)* - { - ndview!(@apply $call(self).as_mut()) - } - }; - (@impl $method:ident.$call:ident::<$view:ident>(&self) where $($rest:tt)*) => { - pub fn $method(&self) -> $crate::params::ParamsBase<$view, D, K> - where - $($rest)* - { - ndview!(@apply $call(self).as_ref()) - } - }; - (@impl $method:ident.$call:ident::<$view:ident>(&mut self) where $($rest:tt)*) => { - pub fn $method(&mut self) -> $crate::params::ParamsBase<$view, D, K> - where - $($rest)* - { - ndview!(@apply $call(self).as_mut()) - } - }; - (@impl $method:ident.$call:ident::<'a, $view:ident>(&self) where $($rest:tt)*) => { - pub fn $method(&self) -> $crate::params::ParamsBase<$view<&'_ A>, D, K> - where - $($rest)* - { - ndview!(@apply $call(&self).as_ref()) - } - }; - (@impl $method:ident.$call:ident::<'a, $view:ident>(&mut self) where $($rest:tt)*) => { - pub fn $method(&mut self) -> 
$crate::params::ParamsBase<$view<&'_ mut A>, D, K> - where - $($rest)* - { - ndview!(@apply $call(self).as_mut()) - } - }; - (@apply $call:ident($self:expr)$(.$as:ident())?) => { - $crate::params::ParamsBase { - bias: $self.bias$(.$as())?.map(|arr| arr.$call()), - weight: $self.weight.$call(), - _mode: $self._mode, - } - }; -} +#[macro_use] +mod model; +#[macro_use] +mod params; diff --git a/models/linear/src/macros/model.rs b/models/linear/src/macros/model.rs new file mode 100644 index 00000000..e0cba303 --- /dev/null +++ b/models/linear/src/macros/model.rs @@ -0,0 +1,22 @@ +/* + Appellation: model + Contrib: FL03 +*/ + +macro_rules! mbuilder { + ($method:ident$(.$call:ident)? where $($rest:tt)*) => { + mbuilder!(@impl $method$(.$call)? where $($rest)*); + }; + (@impl $method:ident where $($rest:tt)*) => { + mbuilder!(@impl $method.$method where $($rest)*); + }; + (@impl $method:ident.$call:ident where $($rest:tt)*) => { + pub fn $method(shape: Sh) -> Self + where + Sh: ndarray::ShapeBuilder, + $($rest)* + { + Linear::from_params($crate::params::ParamsBase::$call(shape)) + } + }; +} diff --git a/models/linear/src/macros/params.rs b/models/linear/src/macros/params.rs new file mode 100644 index 00000000..7da00db5 --- /dev/null +++ b/models/linear/src/macros/params.rs @@ -0,0 +1,92 @@ +/* + Appellation: params + Contrib: FL03 +*/ + +macro_rules! pbuilder { + ($method:ident$(.$call:ident)? where $($rest:tt)*) => { + pbuilder!(@impl $method$(.$call)? 
where $($rest)*); + }; + (@impl $method:ident where $($rest:tt)*) => { + pbuilder!(@impl $method.$method where $($rest)*); + }; + (@impl $method:ident.$call:ident where $($rest:tt)*) => { + pub fn $method(shape: Sh) -> Self + where + K: $crate::params::mode::ParamMode, + Sh: ndarray::ShapeBuilder, + $($rest)* + { + let dim = shape.into_shape().raw_dim().clone(); + ParamsBase { + bias: build_bias(K::BIASED, dim.clone(), |dim| ndarray::ArrayBase::$call(dim)), + weight: ndarray::ArrayBase::$call(dim), + _mode: ::core::marker::PhantomData::, + } + } + }; +} + +macro_rules! wnbview { + ($method:ident::$($rest:tt)*) => { + wnbview!(@impl $method.$method::$($rest)*); + }; + ($method:ident.$call:ident::$($rest:tt)*) => { + wnbview!(@impl $method.$call::$($rest)*); + }; + (@impl $method:ident.$call:ident::<$view:ident>(self) where $($rest:tt)*) => { + pub fn $method(self) -> $crate::params::ParamsBase<$view, D, K> + where + $($rest)* + { + wnbview!(@apply $call(self)) + } + }; + (@impl $method:ident.$call:ident::<$view:ident>(mut self) where $($rest:tt)*) => { + pub fn $method(mut self) -> $crate::params::ParamsBase<$view, D, K> + where + $($rest)* + { + wnbview!(@apply $call(self).as_mut()) + } + }; + (@impl $method:ident.$call:ident::<$view:ident>(&self) where $($rest:tt)*) => { + pub fn $method(&self) -> $crate::params::ParamsBase<$view, D, K> + where + $($rest)* + { + wnbview!(@apply $call(self).as_ref()) + } + }; + (@impl $method:ident.$call:ident::<$view:ident>(&mut self) where $($rest:tt)*) => { + pub fn $method(&mut self) -> $crate::params::ParamsBase<$view, D, K> + where + $($rest)* + { + wnbview!(@apply $call(self).as_mut()) + } + }; + (@impl $method:ident.$call:ident::<'a, $view:ident>(&self) where $($rest:tt)*) => { + pub fn $method(&self) -> $crate::params::ParamsBase<$view<&'_ A>, D, K> + where + $($rest)* + { + wnbview!(@apply $call(&self).as_ref()) + } + }; + (@impl $method:ident.$call:ident::<'a, $view:ident>(&mut self) where $($rest:tt)*) => { + pub fn 
$method(&mut self) -> $crate::params::ParamsBase<$view<&'_ mut A>, D, K> + where + $($rest)* + { + wnbview!(@apply $call(self).as_mut()) + } + }; + (@apply $call:ident($self:expr)$(.$as:ident())?) => { + $crate::params::ParamsBase { + bias: $self.bias$(.$as())?.map(|arr| arr.$call()), + weight: $self.weight.$call(), + _mode: $self._mode, + } + }; +} diff --git a/models/linear/src/model/layer.rs b/models/linear/src/model/layer.rs index 930d0671..486bb43d 100644 --- a/models/linear/src/model/layer.rs +++ b/models/linear/src/model/layer.rs @@ -24,10 +24,11 @@ where impl Linear where D: RemoveAxis, + K: ParamMode, { - impl_model_builder!(new where A: Default); - impl_model_builder!(ones where A: Clone + num::One); - impl_model_builder!(zeros where A: Clone + num::Zero); + mbuilder!(new where A: Default); + mbuilder!(ones where A: Clone + num::One); + mbuilder!(zeros where A: Clone + num::Zero); pub fn from_config(config: Config) -> Self where @@ -48,7 +49,7 @@ where Self { config, params } } - pub(crate) fn from_params(params: LinearParams) -> Self { + pub fn from_params(params: LinearParams) -> Self { let config = Config::::new().with_shape(params.raw_dim()); Self { config, params } } diff --git a/models/linear/src/norm/layer/model.rs b/models/linear/src/norm/layer/model.rs index 8895f083..e5dc6b67 100644 --- a/models/linear/src/norm/layer/model.rs +++ b/models/linear/src/norm/layer/model.rs @@ -25,6 +25,24 @@ where params: LinearParams, } +macro_rules! impl_norm_builder { + ($method:ident$(.$call:ident)? where $($rest:tt)*) => { + impl_norm_builder!(@impl $method$(.$call)? 
where $($rest)*); + }; + (@impl $method:ident where $($rest:tt)*) => { + impl_norm_builder!(@impl $method.$method where $($rest)*); + }; + (@impl $method:ident.$call:ident where $($rest:tt)*) => { + pub fn $method(shape: Sh) -> Self + where + Sh: ShapeBuilder, + $($rest)* + { + Self::from_params(LinearParams::::$call(shape)) + } + }; +} + impl LayerNorm where D: RemoveAxis, @@ -38,38 +56,25 @@ where Self { config, params } } - pub fn default(shape: Sh) -> Self + pub fn from_elem(shape: Sh, elem: A) -> Self where - A: Default, + A: Clone, Sh: ShapeBuilder, { let dim = shape.into_shape().raw_dim().clone(); let config = Config::new().dim(dim.clone()).build(); - let params = LinearParams::::new(dim); + let params = LinearParams::::from_elem(dim, elem); Self { config, params } } - pub fn ones(shape: Sh) -> Self - where - A: Clone + One, - Sh: ShapeBuilder, - { - let dim = shape.into_shape().raw_dim().clone(); - let config = Config::new().dim(dim.clone()).build(); - let params = LinearParams::::ones(dim); + pub fn from_params(params: LinearParams) -> Self { + let config = Config::new().dim(params.raw_dim()).build(); Self { config, params } } - pub fn zeros(shape: Sh) -> Self - where - A: Clone + Zero, - Sh: ShapeBuilder, - { - let dim = shape.into_shape().raw_dim().clone(); - let config = Config::new().dim(dim.clone()).build(); - let params = LinearParams::::zeros(dim); - Self { config, params } - } + impl_norm_builder!(new where A: Default); + impl_norm_builder!(ones where A: Clone + One); + impl_norm_builder!(zeros where A: Clone + Zero); pub const fn config(&self) -> &Config { &self.config diff --git a/models/linear/src/params/store.rs b/models/linear/src/params/store.rs index 90e2809c..b8afc752 100644 --- a/models/linear/src/params/store.rs +++ b/models/linear/src/params/store.rs @@ -110,25 +110,25 @@ where crate::is_biased::() } - impl_params_builder!(new.default where A: Default, S: DataOwned); + pbuilder!(new.default where A: Default, S: DataOwned); - 
impl_params_builder!(ones where A: Clone + One, S: DataOwned); + pbuilder!(ones where A: Clone + One, S: DataOwned); - impl_params_builder!(zeros where A: Clone + Zero, S: DataOwned); + pbuilder!(zeros where A: Clone + Zero, S: DataOwned); dimensional!(weights()); - ndview!(into_owned::(self) where A: Clone, S: Data); + wnbview!(into_owned::(self) where A: Clone, S: Data); - ndview!(into_shared::(self) where A: Clone, S: DataOwned); + wnbview!(into_shared::(self) where A: Clone, S: DataOwned); - ndview!(to_owned::(&self) where A: Clone, S: Data); + wnbview!(to_owned::(&self) where A: Clone, S: Data); - ndview!(to_shared::(&self) where A: Clone, S: DataOwned); + wnbview!(to_shared::(&self) where A: Clone, S: DataOwned); - ndview!(view::<'a, ViewRepr>(&self) where A: Clone, S: Data); + wnbview!(view::<'a, ViewRepr>(&self) where A: Clone, S: Data); - ndview!(view_mut::<'a, ViewRepr>(&mut self) where A: Clone, S: DataMut); + wnbview!(view_mut::<'a, ViewRepr>(&mut self) where A: Clone, S: DataMut); } impl ParamsBase diff --git a/models/transformers/Cargo.toml b/models/transformers/Cargo.toml index 7e91e5e3..00bf9fb7 100644 --- a/models/transformers/Cargo.toml +++ b/models/transformers/Cargo.toml @@ -28,27 +28,33 @@ full = [ alloc = [ "concision-core/alloc", + "concision-linear/alloc", + "serde?/alloc", ] approx = [ "dep:approx", "concision-core/approx", + "concision-linear/approx", "ndarray/approx-0_5", ] blas = [ "concision-core/blas", + "concision-linear/blas", "ndarray/blas", ] rand = [ "concision-core/rand", + "concision-linear/rand", "num/rand" ] serde = [ "serde-1", "concision-core/serde", + "concision-linear/serde", "ndarray/serde-1", "num/serde" ] @@ -59,11 +65,14 @@ serde-1 = [ tracing = [ "dep:tracing", + "concision-core/tracing", + "concision-linear/tracing", ] # ********* [FF] Environments ********* std = [ "concision-core/std", + "concision-linear/std", "ndarray/std", "num/std", "serde?/std", @@ -72,10 +81,12 @@ std = [ wasm = [ "concision-core/wasm", + 
"concision-linear/wasm", ] wasi = [ "concision-core/wasi", + "concision-linear/wasi", ] [lib] @@ -102,6 +113,11 @@ default-features = false path = "../../core" version = "0.1.14" +[dependencies.concision-linear] +default-features = false +path = "../linear" +version = "0.1.14" + [dependencies.serde] default-features = false features = ["derive"] diff --git a/models/transformers/src/attention/head.rs b/models/transformers/src/attention/head.rs index 10e0e863..c5146a34 100644 --- a/models/transformers/src/attention/head.rs +++ b/models/transformers/src/attention/head.rs @@ -9,6 +9,7 @@ use nd::linalg::Dot; use nd::*; use num::complex::ComplexFloat; +// #68 pub struct AttentionHead> where D: Dimension, diff --git a/models/transformers/src/attention/mod.rs b/models/transformers/src/attention/mod.rs index 2b7e1060..80a264c7 100644 --- a/models/transformers/src/attention/mod.rs +++ b/models/transformers/src/attention/mod.rs @@ -2,11 +2,18 @@ Appellation: attention Contrib: FL03 */ +//! # Attention +//! +//! Attention allows a model to focus on specific parts of the input sequence. +//! Today, these mechanisms are found in several state-of-the-art models, such as +//! the Transformer model, primarily due to its capabilities in natural language +//! 
processing (NLP) domains pub use self::head::AttentionHead; pub use self::utils::*; pub(crate) mod head; +// #69: Multi-Head Attention implementation pub mod multi; pub(crate) mod prelude { @@ -28,7 +35,7 @@ pub(crate) mod utils { A::from(dk).unwrap().sqrt().recip() } - /// Scaled dot-product attention; + /// A functional implementation of the scaled dot-product attention mechanism; pub fn scaled_dot_product_attention( q: &ArrayBase, k: &ArrayBase, diff --git a/models/transformers/src/codec/decoder.rs b/models/transformers/src/codec/decoder.rs index b8c8dd6a..3d00af1c 100644 --- a/models/transformers/src/codec/decoder.rs +++ b/models/transformers/src/codec/decoder.rs @@ -2,15 +2,30 @@ Appellation: decoder Contrib: FL03 */ -pub use self::layer::DecoderLayer; +pub use self::{config::DecoderConfig, layer::DecoderLayer}; +pub mod config; pub mod layer; #[derive(Default)] -pub struct Decoder {} +pub struct Decoder { + config: DecoderConfig, + layers: Vec, +} impl Decoder { pub fn new() -> Self { - Self {} + Self { + config: DecoderConfig::default(), + layers: Vec::new(), + } + } + + pub const fn config(&self) -> &DecoderConfig { + &self.config + } + + pub fn layers(&self) -> &[DecoderLayer] { + &self.layers } } diff --git a/models/transformers/src/codec/decoder/config.rs b/models/transformers/src/codec/decoder/config.rs new file mode 100644 index 00000000..8056b7cd --- /dev/null +++ b/models/transformers/src/codec/decoder/config.rs @@ -0,0 +1,14 @@ +/* + Appellation: config + Contrib: FL03 +*/ + +pub struct DecoderConfig { + pub layers: usize, +} + +impl Default for DecoderConfig { + fn default() -> Self { + Self { layers: crate::N } + } +} diff --git a/models/transformers/src/codec/encoder.rs b/models/transformers/src/codec/encoder.rs index a9e6c4ad..e9fdb490 100644 --- a/models/transformers/src/codec/encoder.rs +++ b/models/transformers/src/codec/encoder.rs @@ -2,15 +2,38 @@ Appellation: encoder Contrib: FL03 */ -pub use self::layer::EncoderLayer; +pub use 
self::{config::EncoderConfig, layer::EncoderLayer}; +pub mod config; pub mod layer; +use linear::norm::LayerNorm; + #[derive(Default)] -pub struct Encoder {} +pub struct Encoder { + config: EncoderConfig, + layers: Vec, + norm: LayerNorm, +} impl Encoder { pub fn new() -> Self { - Self {} + Self { + config: EncoderConfig::default(), + layers: Vec::new(), + norm: LayerNorm::default(), + } + } + + pub const fn config(&self) -> &EncoderConfig { + &self.config + } + + pub fn layers(&self) -> &[EncoderLayer] { + &self.layers + } + + pub fn norm(&self) -> &LayerNorm { + &self.norm } } diff --git a/models/transformers/src/codec/encoder/config.rs b/models/transformers/src/codec/encoder/config.rs new file mode 100644 index 00000000..2c5dcf93 --- /dev/null +++ b/models/transformers/src/codec/encoder/config.rs @@ -0,0 +1,14 @@ +/* + Appellation: config + Contrib: FL03 +*/ + +pub struct EncoderConfig { + pub layers: usize, +} + +impl Default for EncoderConfig { + fn default() -> Self { + Self { layers: crate::N } + } +} diff --git a/models/transformers/src/lib.rs b/models/transformers/src/lib.rs index 7a40e9c8..ed9cf63e 100644 --- a/models/transformers/src/lib.rs +++ b/models/transformers/src/lib.rs @@ -14,14 +14,17 @@ extern crate alloc; extern crate concision_core as concision; +extern crate concision_linear as linear; extern crate ndarray as nd; pub use self::attention::AttentionHead; pub use self::params::*; +pub use self::primitives::*; pub use self::transformer::Transformer; #[macro_use] pub(crate) mod macros; +pub(crate) mod primitives; pub(crate) mod transformer; pub mod attention; @@ -37,5 +40,6 @@ pub(crate) mod impls { pub mod prelude { pub use super::attention::prelude::*; + pub use super::primitives::*; pub use super::Transformer; } diff --git a/models/transformers/src/ops/merge.rs b/models/transformers/src/ops/merge.rs index f7ada237..747cae82 100644 --- a/models/transformers/src/ops/merge.rs +++ b/models/transformers/src/ops/merge.rs @@ -4,49 +4,39 @@ */ use 
concision::NdResult; use nd::prelude::*; -use nd::{Data, Order}; +use nd::{Data, RemoveAxis}; +// #67: Optimize the Merge trait pub trait Merge { type Output; - fn merge(self) -> Self::Output; + fn merge(&self) -> NdResult { + self.merge_along(0) + } + + fn merge_along(&self, axis: usize) -> NdResult; } /* ************* Implementations ************* */ -impl Merge for ArrayBase +impl Merge for ArrayBase where A: Clone, + D: RemoveAxis, + E: Dimension, S: Data, + ArrayBase: Clone, { - type Output = NdResult>; + type Output = Array; - fn merge(self) -> Self::Output { - let (heads, seq, query) = self.dim(); - let mut tmp = self; - // swap the head and sequence axes - tmp.swap_axes(0, 1); - // reshape the qkv matrix into a 2d array - let res = tmp.to_shape(((seq, heads * query), Order::ColumnMajor))?; - Ok(res.to_owned()) + fn merge(&self) -> NdResult { + let swap = if self.ndim() >= 3 { self.ndim() - 3 } else { 0 }; + self.merge_along(swap) } -} - -impl Merge for ArrayBase -where - A: Clone, - S: Data, -{ - type Output = NdResult>; - fn merge(self) -> Self::Output { - let (batch, heads, seq, query) = self.dim(); - let mut tmp = self; - // swap the head and sequence axes - tmp.swap_axes(1, 2); - // reshape the qkv matrix into a 2d array - let res = tmp.to_shape(((batch, seq, heads * query), Order::ColumnMajor))?; - Ok(res.to_owned()) + fn merge_along(&self, swap: usize) -> NdResult { + use ndarray::Order; + super::merger(self, swap, swap + 1, Order::RowMajor) } } diff --git a/models/transformers/src/ops/mod.rs b/models/transformers/src/ops/mod.rs index ee4e98fa..778e4abc 100644 --- a/models/transformers/src/ops/mod.rs +++ b/models/transformers/src/ops/mod.rs @@ -13,74 +13,60 @@ pub(crate) mod utils { use nd::{Data, Order, RemoveAxis}; #[doc(hidden)] - pub fn merge( - z: &mut ArrayBase, - swap: usize, - with: usize, - ) -> NdResult> + pub fn merge( + arr: &ArrayBase, + src: usize, + tgt: usize, + ) -> NdResult> where A: Clone, + D: RemoveAxis, S: Data, - D: RemoveAxis, 
- E: Dimension, + D::Smaller: Dimension, + ArrayBase: Clone, { - let cur = z.raw_dim().as_array_view().to_owned(); - let indicies = (0..cur.ndim()).filter(|&i| i != swap).collect::>(); - let new_axis = cur[swap] * cur[with]; - let mut dim = cur.select(Axis(0), &indicies); - dim[with - 1] = new_axis; - - // swap the head and sequence axes - z.swap_axes(swap, with); - // reshape the qkv matrix into a smaller dimension - // z.to_shape((dim, Order::ColumnMajor)) - unimplemented!() + merger(arr, src, tgt, Order::RowMajor) } - #[doc(hidden)] - pub fn merge_simple( - z: &mut ArrayBase, - dim: E, - swap: usize, - with: usize, - ) -> NdResult> + + pub(crate) fn merger( + arr: &ArrayBase, + src: usize, + tgt: usize, + order: Order, + ) -> NdResult> where A: Clone, + D: RemoveAxis, S: Data, - D: RemoveAxis, - E: Dimension, + D::Smaller: Dimension, + ArrayBase: Clone, { - // swap the head and sequence axes - z.swap_axes(swap, with); - // reshape the qkv matrix into a smaller dimension - z.to_shape((dim, Order::ColumnMajor)) + let shape = merge_dims(arr.raw_dim(), src); + let mut head = arr.clone(); + head.swap_axes(src, tgt); + head.to_shape((shape, order)).map(|x| x.to_owned()) } - pub fn merge_heads(heads: &Array3) -> NdResult> + #[doc(hidden)] + pub fn merge_dims(dim: D, src: usize) -> D::Smaller where - A: Clone, + D: RemoveAxis, + D::Smaller: Dimension, { - let (n, seq, query) = heads.dim(); - let mut tmp = heads.clone(); - // swap the head and sequence axes - tmp.swap_axes(0, 1); - // reshape the qkv matrix into a 2d array - tmp.into_shape((seq, n * query)) - } + // create a new dimension with one less axis; initialized with zeros + let mut new_dim = ::Smaller::zeros(dim.ndim() - 1); + // create a mutable vector from the slice + let mut shape = dim.slice().to_vec(); + // multiply the last axis by the target + shape[new_dim.ndim()] *= shape[src]; + // remove the last dimension + shape.remove(src); - pub fn split_heads(param: &Array2, num_heads: usize) -> NdResult> - where 
- T: Clone, - { - let dim = param.shape().last().unwrap() / num_heads; - // reshape the qkv matrix into a 3d array - let mut res = param - .clone() - .into_shape((param.shape()[0], num_heads, dim))?; - // swap the sequence and head axes - res.swap_axes(0, 1); - Ok(res) + new_dim.slice_mut().copy_from_slice(&shape); + new_dim } + #[doc(hidden)] pub fn merge_batch(heads: &Array4) -> NdResult> where T: Clone, @@ -93,38 +79,29 @@ pub(crate) mod utils { tmp.into_shape((batch, seq, n * query)) } - pub fn split_batch(param: &Array3, num_heads: usize) -> NdResult> + pub fn split_heads(param: &Array2, h: usize) -> NdResult> where T: Clone, { - let dim = param.shape().last().unwrap() / num_heads; + let dim = param.shape().last().unwrap() / h; // reshape the qkv matrix into a 3d array - let mut res = - param - .clone() - .into_shape((param.shape()[0], param.shape()[1], num_heads, dim))?; + let mut res = param.clone().into_shape((param.shape()[0], h, dim))?; // swap the sequence and head axes - res.swap_axes(1, 2); + res.swap_axes(0, 1); Ok(res) } -} - -#[cfg(test)] -mod tests { - use super::*; - use ndarray::Array; - #[test] - fn reshape_ops() { - let dim_input: [usize; 3] = [2, 4, 6]; // (batch, seq, model) - let dim_split = [2, 2, 4, 3]; // (batch, heads, seq, model) - let data = Array::linspace(1., 48., 48).into_shape(dim_input).unwrap(); - - let a = split_batch(&data, 2).unwrap(); - assert_eq!(a.shape(), &dim_split); - assert_eq!(&a, &data.split(2).unwrap()); - let b = merge_batch(&a).unwrap(); - assert_eq!(b.shape(), &dim_input); - assert_eq!(&b, &data); + pub fn split_batch(param: &Array3, h: usize) -> NdResult> + where + T: Clone, + { + let dim = param.shape().last().unwrap() / h; + // reshape the qkv matrix into a 3d array + let mut res = param + .clone() + .into_shape((param.shape()[0], param.shape()[1], h, dim))?; + // swap the sequence and head axes + res.swap_axes(1, 2); + Ok(res) } } diff --git a/models/transformers/src/ops/split.rs 
b/models/transformers/src/ops/split.rs index 13c367e4..3a182710 100644 --- a/models/transformers/src/ops/split.rs +++ b/models/transformers/src/ops/split.rs @@ -15,16 +15,16 @@ use ndarray::ShapeError; // Ok(res) // } -pub trait Split { - type Error; +pub trait Split { + type Output; - fn split(&self, heads: usize) -> Result; + fn split(&self, heads: usize) -> Result; } -impl Split> for Array2 { - type Error = ShapeError; +impl Split for Array2 { + type Output = Array3; - fn split(&self, heads: usize) -> Result, Self::Error> { + fn split(&self, heads: usize) -> Result { let (seq, model) = self.dim(); let query = model / heads; // reshape the qkv matrix into a 3d array @@ -35,10 +35,10 @@ impl Split> for Array2 { } } -impl Split> for Array3 { - type Error = ShapeError; +impl Split for Array3 { + type Output = Array4; - fn split(&self, heads: usize) -> Result, Self::Error> { + fn split(&self, heads: usize) -> Result { let (batch, seq, model) = self.dim(); let query = model / heads; // reshape the qkv matrix into a 3d array diff --git a/models/transformers/src/primitives.rs b/models/transformers/src/primitives.rs new file mode 100644 index 00000000..96db829b --- /dev/null +++ b/models/transformers/src/primitives.rs @@ -0,0 +1,10 @@ +/* + Appellation: primitives + Contrib: FL03 +*/ +pub use self::consts::*; + +pub mod consts { + /// The default number of layers used for the encoder / decoder. 
+ pub const N: usize = 6; +} diff --git a/models/transformers/tests/ops.rs b/models/transformers/tests/ops.rs new file mode 100644 index 00000000..c39b8efa --- /dev/null +++ b/models/transformers/tests/ops.rs @@ -0,0 +1,52 @@ +/* + Appellation: ops + Contrib: FL03 +*/ +extern crate concision_core as concision; +extern crate concision_transformers as transformers; + +use concision::linarr; +use ndarray::prelude::*; +use transformers::ops::*; + +#[test] +fn test_merge() { + let shape = (3, 4, 5); + let dout = (4, 15); + let arr = linarr::(shape.clone()).unwrap(); + let a = arr.clone().merge().unwrap(); + let b = merge(&arr, 0, 1).unwrap(); + + assert_eq!(a.dim(), dout); + assert_eq!(a.dim(), b.dim()); + assert_eq!(a, b); +} + +#[test] +fn test_merge_batch() { + let shape = (2, 3, 4, 5); + let dout = (2, 4, 15); + let arr = linarr::(shape).unwrap(); + let a = arr.merge().unwrap(); + let b = merge(&arr, 1, 2).unwrap(); + + assert_eq!(a.dim(), dout); + assert_eq!(a, b); +} + +#[test] +fn reshape_ops() { + let dim_input: [usize; 3] = [2, 4, 6]; // (batch, seq, model) + let dim_split = [2, 2, 4, 3]; // (batch, heads, seq, model) + let data = linarr::(dim_input).unwrap(); + + let a = split_batch(&data, 2).unwrap(); + let b = a.merge().unwrap(); // merge_batch(&a).unwrap(); + + assert_eq!(a.shape(), &dim_split); + assert_eq!(b.shape(), &dim_input); + assert_eq!(a, data.split(2).unwrap()); + for (i, &j) in b.indexed_iter() { + assert_eq!(j, data[i]); + } +}