diff --git a/Cargo.lock b/Cargo.lock index d70e3286..999bd37d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -468,6 +468,26 @@ dependencies = [ "syn", ] +[[package]] +name = "equator" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c35da53b5a021d2484a7cc49b2ac7f2d840f8236a286f84202369bd338d761ea" +dependencies = [ + "equator-macro", +] + +[[package]] +name = "equator-macro" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -998,6 +1018,69 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "nano-gemm" +version = "0.1.2" +source = "git+https://github.com/sarah-ek/nano-gemm#86385315356ab81c66aabf70bcbe10ad9aa3ebae" +dependencies = [ + "equator", + "nano-gemm-c32", + "nano-gemm-c64", + "nano-gemm-codegen", + "nano-gemm-core", + "nano-gemm-f32", + "nano-gemm-f64", + "num-complex", +] + +[[package]] +name = "nano-gemm-c32" +version = "0.1.0" +source = "git+https://github.com/sarah-ek/nano-gemm#86385315356ab81c66aabf70bcbe10ad9aa3ebae" +dependencies = [ + "nano-gemm-codegen", + "nano-gemm-core", + "num-complex", +] + +[[package]] +name = "nano-gemm-c64" +version = "0.1.0" +source = "git+https://github.com/sarah-ek/nano-gemm#86385315356ab81c66aabf70bcbe10ad9aa3ebae" +dependencies = [ + "nano-gemm-codegen", + "nano-gemm-core", + "num-complex", +] + +[[package]] +name = "nano-gemm-codegen" +version = "0.1.0" +source = "git+https://github.com/sarah-ek/nano-gemm#86385315356ab81c66aabf70bcbe10ad9aa3ebae" + +[[package]] +name = "nano-gemm-core" +version = "0.1.0" +source = "git+https://github.com/sarah-ek/nano-gemm#86385315356ab81c66aabf70bcbe10ad9aa3ebae" + +[[package]] +name = "nano-gemm-f32" +version = "0.1.0" +source = 
"git+https://github.com/sarah-ek/nano-gemm#86385315356ab81c66aabf70bcbe10ad9aa3ebae" +dependencies = [ + "nano-gemm-codegen", + "nano-gemm-core", +] + +[[package]] +name = "nano-gemm-f64" +version = "0.1.0" +source = "git+https://github.com/sarah-ek/nano-gemm#86385315356ab81c66aabf70bcbe10ad9aa3ebae" +dependencies = [ + "nano-gemm-codegen", + "nano-gemm-core", +] + [[package]] name = "ndarray" version = "0.15.6" @@ -1045,9 +1128,9 @@ dependencies = [ [[package]] name = "num-complex" -version = "0.4.4" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ "num-traits", ] @@ -1064,9 +1147,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.17" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", ] @@ -1237,6 +1320,7 @@ dependencies = [ "itertools", "lz4_flex", "managed-lhapdf", + "nano-gemm", "ndarray", "ndarray-npy", "num-complex", diff --git a/pineappl/Cargo.toml b/pineappl/Cargo.toml index c279937e..383ea373 100644 --- a/pineappl/Cargo.toml +++ b/pineappl/Cargo.toml @@ -25,6 +25,7 @@ float-cmp = "0.9.0" git-version = "0.3.5" itertools = "0.10.1" lz4_flex = "0.9.2" +nano-gemm = { git = "https://github.com/sarah-ek/nano-gemm", package = "nano-gemm" } ndarray = { features = ["serde"], version = "0.15.4" } rustc-hash = "1.1.0" serde = { features = ["derive"], version = "1.0.130" } diff --git a/pineappl/src/evolution.rs b/pineappl/src/evolution.rs index 95667057..8bbd2cc3 100644 --- a/pineappl/src/evolution.rs +++ b/pineappl/src/evolution.rs @@ -10,8 +10,7 @@ use super::sparse_array3::SparseArray3; use super::subgrid::{Mu2, Subgrid, 
SubgridEnum}; use float_cmp::approx_eq; use itertools::Itertools; -use ndarray::linalg; -use ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView4, Axis}; +use ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView4, Axis}; use std::iter; /// Number of ULPS used to de-duplicate grid values in [`Grid::evolve_info`]. @@ -170,6 +169,59 @@ fn gluon_has_pid_zero(grid: &Grid) -> bool { && grid.pid_basis() == PidBasis::Pdg }
+/// Multiplies two `f64` matrices with `nano-gemm`.
+///
+/// Takes its arguments in the same order as `ndarray::linalg::general_mat_mul`, which it
+/// replaces at its call sites and which computes `c = alpha * a * b + beta * c`.
+///
+/// # Panics
+///
+/// Panics if the number of columns of `a` differs from the number of rows of `b`, or if `c`
+/// does not have as many rows as `a` and as many columns as `b`.
+fn nano_gemm_mat_mul(
+    alpha: f64,
+    a: ArrayView2<f64>,
+    b: ArrayView2<f64>,
+    beta: f64,
+    c: &mut Array2<f64>,
+) {
+    use nano_gemm::planless;
+
+    let ((m, k), (k_rhs, n)) = (a.dim(), b.dim());
+
+    // the kernel works on raw pointers and cannot detect mismatched shapes, so reject them here
+    // instead of reading or writing out of bounds
+    assert_eq!(k, k_rhs, "inner dimensions of `a` and `b` must agree");
+    assert_eq!(c.dim(), (m, n), "`c` must have the shape of `a * b`");
+
+    // SAFETY: every pointer is passed together with the strides of the array it was taken from,
+    // and the assertions above ensure that `m`, `n` and `k` match the dimensions of those
+    // arrays; `c` is borrowed mutably and therefore cannot overlap `a` or `b`
+    unsafe {
+        planless::execute_f64(
+            m,
+            n,
+            k,
+            c.as_mut_ptr(),
+            c.strides()[0],
+            c.strides()[1],
+            a.as_ptr(),
+            a.strides()[0],
+            a.strides()[1],
+            b.as_ptr(),
+            b.strides()[0],
+            b.strides()[1],
+            // NOTE(review): the scalars are passed in the opposite order of this function's
+            // parameters; presumably `nano-gemm` expects the factor of the destination first -
+            // confirm against its API documentation
+            beta,
+            alpha,
+            false,
+            false,
+        );
+    }
+}
+
 type Pid01IndexTuples = Vec<(usize, usize)>; type Pid01Tuples = Vec<(i32, i32)>; @@ -602,8 +654,8 @@ pub(crate) fn evolve_slice_with_two( .map(|(opa, opb)| (fk_table, opa, opb)) }, ) { - linalg::general_mat_mul(1.0, &array, &opb.t(), 0.0, &mut tmp); - linalg::general_mat_mul(factor, opa, &tmp, 1.0, fk_table); + nano_gemm_mat_mul(1.0, array.view(), opb.t(), 0.0, &mut tmp); + nano_gemm_mat_mul(factor, opa.view(), tmp.view(), 1.0, fk_table); } } }