Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
File renamed without changes.
12 changes: 6 additions & 6 deletions .artifacts/archive/nn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,17 @@ test = false
[build-dependencies]

[dependencies]
concision-neural = { path = "./neural", version = "0.1.12" }
concision-nlp = { optional = true, path = "./nlp", version = "0.1.12" }
concision-optim = { optional = true, path = "./optim", version = "0.1.12" }
concision-s4 = { optional = true, path = "./s4", version = "0.1.12" }
transformers = { optional = true, path = "./transformers", version = "0.1.12" }
concision-neural = { path = "../neural", version = "0.1.12" }
concision-nlp = { optional = true, path = "../nlp", version = "0.1.12" }
concision-optim = { optional = true, path = "../optim", version = "0.1.12" }
concision-s4 = { optional = true, path = "../s4", version = "0.1.12" }
transformers = { optional = true, path = "../transformers", version = "0.1.12" }


[dev-dependencies]
anyhow = "1"
approx = "0.5"
concision = { path = "../../concision" }
concision = { path = "../../../concision" }
ndarray = { features = ["approx-0_5", "serde-1"], version = "0.15" }

[package.metadata.docs.rs]
Expand Down
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ approx = "0.5"
itertools = "0.12"
lazy_static = "1"
ndarray = { default-features = false, version = "0.15" }

num = { default-features = false, version = "0.4" }
paste = "1"
smart-default = "0.7"
strum = { default-features = false, features = ["derive"], version = "0.26" }

Expand Down
8 changes: 4 additions & 4 deletions core/src/traits/arr/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub trait Inverse {
pub trait Matmul<Rhs = Self> {
type Output;

fn matmul(&self, rhs: Rhs) -> Self::Output;
fn matmul(&self, rhs: &Rhs) -> Self::Output;
}

pub trait Matpow<Rhs = Self> {
Expand Down Expand Up @@ -56,14 +56,14 @@ where
}
}

impl<X, Y, S> Matmul<X> for S
impl<S, X, Y> Matmul<X> for S
where
S: Dot<X, Output = Y>,
{
type Output = Y;

fn matmul(&self, rhs: X) -> Self::Output {
self.dot(&rhs)
fn matmul(&self, rhs: &X) -> Self::Output {
self.dot(rhs)
}
}

Expand Down
2 changes: 1 addition & 1 deletion models/gnn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ std = [
"concision-core/std",
"ndarray/std",
"num/std",
"serde/std",
"serde?/std",
"strum/std",
]

Expand Down
2 changes: 1 addition & 1 deletion models/kan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ std = [
"concision-core/std",
"ndarray/std",
"num/std",
"serde/std",
"serde?/std",
"strum/std",
]

Expand Down
123 changes: 0 additions & 123 deletions models/transformers/Cargo copy.toml

This file was deleted.

3 changes: 2 additions & 1 deletion models/transformers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ std = [
"concision-core/std",
"ndarray/std",
"num/std",
"serde/std",
"serde?/std",
"strum/std",
]

Expand All @@ -89,6 +89,7 @@ test = true
[dependencies]
ndarray.workspace = true
num.workspace = true
paste.workspace = true
smart-default.workspace = true
strum.workspace = true

Expand Down
45 changes: 45 additions & 0 deletions models/transformers/src/attention/head.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
Appellation: head <module>
Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::params::QKVBase;
use nd::*;

/// A single attention head, generic over the element type `A`, the ndarray
/// storage backend `S`, and the dimensionality `D` (defaults: `f64`, owned
/// storage, 2-D). It is a thin wrapper around the query/key/value parameter
/// triple [`QKVBase`].
pub struct AttentionHead<A = f64, S = OwnedRepr<A>, D = Ix2>
where
    D: Dimension,
    S: RawData<Elem = A>,
{
    // the Q/K/V parameter set backing this head
    params: QKVBase<S, D>,
}

impl<A, S, D> AttentionHead<A, S, D>
where
    D: Dimension,
    S: RawData<Elem = A>,
{
    /// Wraps an existing Q/K/V parameter set in an `AttentionHead`.
    pub fn from_params(params: QKVBase<S, D>) -> Self {
        Self { params }
    }

    /// Constructs a head of the given `shape`, delegating to
    /// [`QKVBase::builder`] so that `builder` initializes each of the
    /// Q, K, and V arrays.
    pub fn builder<Sh, F>(shape: Sh, builder: F) -> Self
    where
        F: Fn(D) -> ArrayBase<S, D>,
        Sh: ShapeBuilder<Dim = D>,
    {
        Self::from_params(QKVBase::builder(shape, builder))
    }

    /// Returns an immutable reference to the underlying Q/K/V parameters.
    pub fn params(&self) -> &QKVBase<S, D> {
        &self.params
    }

    /// Returns a mutable reference to the underlying Q/K/V parameters.
    pub fn params_mut(&mut self) -> &mut QKVBase<S, D> {
        &mut self.params
    }

    // NOTE(review): project-local macros — presumably `access!` generates
    // `q`/`k`/`v` accessors on `params`, and `fwd_builder!` forwards the
    // `new`/`ones`/`zeros` constructors to `QKVBase` under the stated
    // bounds; confirm against the macro definitions in this crate.
    access!(params::<q, k, v>);
    fwd_builder!(new.default where A: Default, S: DataOwned);
    fwd_builder!(ones.ones where A: Clone + num::One, S: DataOwned);
    fwd_builder!(zeros.zeros where A: Clone + num::Zero, S: DataOwned);
}
9 changes: 9 additions & 0 deletions models/transformers/src/attention/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,12 @@
Appellation: attention <module>
Contrib: FL03 <jo3mccain@icloud.com>
*/
pub use self::head::AttentionHead;

pub(crate) mod head;

pub mod multi;

pub(crate) mod prelude {
pub use super::head::AttentionHead;
}
10 changes: 10 additions & 0 deletions models/transformers/src/attention/multi/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
    Appellation: multi <module>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
//! # Multi-Head Attention
//!
//!
// re-export the implementation module's public items at this level
pub use self::multi_head::*;

// implementation lives here; consumers use the re-exports above
pub(crate) mod multi_head;
6 changes: 6 additions & 0 deletions models/transformers/src/attention/multi/multi_head.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
/*
    Appellation: multi_head <module>
    Contrib: FL03 <jo3mccain@icloud.com>
*/

/// Placeholder for the multi-head attention mechanism; currently a unit
/// struct with no fields or behavior implemented.
pub struct MultiHeadAttention;
19 changes: 19 additions & 0 deletions models/transformers/src/impls/impl_head.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
Appellation: impl_head <module>
Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::attention::AttentionHead;
use crate::params::QKVBase;
use nd::prelude::*;
use nd::DataOwned;

/// `Default` is available whenever the element type is `Default` and the
/// storage owns its data; it simply delegates to `QKVBase::default()`.
impl<A, S, D> Default for AttentionHead<A, S, D>
where
    A: Default,
    D: Dimension,
    S: DataOwned<Elem = A>,
{
    fn default() -> Self {
        Self::from_params(QKVBase::default())
    }
}
50 changes: 50 additions & 0 deletions models/transformers/src/impls/impl_linalg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
Appellation: impl_linalg <module>
Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::params::{QKVBase, QKV};
use concision::Matmul;
use nd::linalg::Dot;
use nd::*;

/// Matrix product of two Q/K/V parameter sets, applied pairwise across the
/// triple: this set's `q` is dotted with `rhs.q`, `k` with `rhs.k`, and `v`
/// with `rhs.v`. The output dimension `F` is whatever `Dot` produces for
/// `D x E` operands.
// NOTE(review): the pairing is q·q / k·k / v·v rather than any cross
// combination — looks intentional (element-wise over the triple), but worth
// confirming against the attention math this is meant to serve.
impl<A, S, T, D, E, F> Matmul<QKVBase<T, E>> for QKVBase<S, D>
where
    A: LinalgScalar,
    D: Dimension,
    E: Dimension,
    F: Dimension,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
    ArrayBase<S, D>: Dot<ArrayBase<T, E>, Output = Array<A, F>>,
{
    type Output = QKV<A, F>;

    fn matmul(&self, rhs: &QKVBase<T, E>) -> Self::Output {
        QKVBase {
            q: self.q().dot(rhs.q()),
            k: self.k().dot(rhs.k()),
            v: self.v().dot(rhs.v()),
        }
    }
}

/// Matrix product of a Q/K/V parameter set with a single array: the same
/// `rhs` is dotted with each of `q`, `k`, and `v`, yielding an owned
/// [`QKV`] whose dimension `F` is determined by the `Dot` bound.
impl<A, S, T, D, E, F> Matmul<ArrayBase<T, E>> for QKVBase<S, D>
where
    A: LinalgScalar,
    D: Dimension,
    E: Dimension,
    F: Dimension,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
    ArrayBase<S, D>: Dot<ArrayBase<T, E>, Output = Array<A, F>>,
{
    type Output = QKV<A, F>;

    fn matmul(&self, rhs: &ArrayBase<T, E>) -> Self::Output {
        QKVBase {
            q: self.q().dot(rhs),
            k: self.k().dot(rhs),
            v: self.v().dot(rhs),
        }
    }
}
Loading