diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e354ea5..770e205 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -254,6 +254,7 @@ jobs: CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} run: | rustup default stable + cargo publish -p neopdf_legacy cargo publish -p neopdf cargo publish -p neopdf_tmdlib cargo publish -p neopdf_capi diff --git a/CHANGELOG.md b/CHANGELOG.md index 4759734..0eef590 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Added routines to compute uncertainties from non-perturbative functions. + +### Changed + +- Separate `neopdf_legacy` to be its own crate. + ## [0.3.2] - 29/03/2026 ## [0.3.1] - 18/03/2026 diff --git a/Cargo.lock b/Cargo.lock index 1381d43..8a52e04 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -143,19 +143,6 @@ dependencies = [ "tempfile", ] -[[package]] -name = "async-compression" -version = "0.4.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddb939d66e4ae03cee6091612804ba446b12878410cfa17f785f4dd67d4014e8" -dependencies = [ - "flate2", - "futures-core", - "memchr", - "pin-project-lite", - "tokio", -] - [[package]] name = "atk" version = "0.18.2" @@ -458,12 +445,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" -[[package]] -name = "cfg_aliases" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" - [[package]] name = "chrono" version = "0.4.43" @@ -1265,7 +1246,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", - "futures-sink", ] [[package]] @@ 
-1467,10 +1447,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", - "js-sys", "libc", "wasi 0.11.1+wasi-snapshot-preview1", - "wasm-bindgen", ] [[package]] @@ -1480,11 +1458,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", - "js-sys", "libc", "r-efi", "wasi 0.14.2+wasi-0.2.4", - "wasm-bindgen", ] [[package]] @@ -1828,23 +1804,6 @@ dependencies = [ "want", ] -[[package]] -name = "hyper-rustls" -version = "0.27.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" -dependencies = [ - "http", - "hyper", - "hyper-util", - "rustls", - "rustls-pki-types", - "tokio", - "tokio-rustls", - "tower-service", - "webpki-roots", -] - [[package]] name = "hyper-util" version = "0.1.15" @@ -2361,12 +2320,6 @@ version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" -[[package]] -name = "lru-slab" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" - [[package]] name = "lz4_flex" version = "0.11.5" @@ -2532,30 +2485,6 @@ dependencies = [ "jni-sys", ] -[[package]] -name = "neopdf" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61b0aaca320ec52a7b70f90e14b5046b194db82fcdf59d0692e539aff111e816" -dependencies = [ - "bincode", - "flate2", - "git-version", - "indicatif", - "itertools 0.13.0", - "lz4_flex", - "ndarray", - "ninterp", - "rayon", - "regex", - "reqwest", - "serde", - "serde_yaml", - "tar", - "tempfile", - "thiserror 1.0.69", -] - [[package]] name = "neopdf" version = "0.3.2" @@ 
-2568,16 +2497,16 @@ dependencies = [ "itertools 0.13.0", "lz4_flex", "ndarray", - "neopdf 0.2.0", + "neopdf_legacy", "ninterp", "rayon", "regex", - "reqwest", "serde", "serde_yaml", "tar", "tempfile", "thiserror 1.0.69", + "ureq", ] [[package]] @@ -2586,7 +2515,7 @@ version = "0.3.2" dependencies = [ "cbindgen", "ndarray", - "neopdf 0.3.2", + "neopdf", ] [[package]] @@ -2597,7 +2526,7 @@ dependencies = [ "assert_fs", "clap 4.5.41", "ndarray", - "neopdf 0.3.2", + "neopdf", "neopdf_tmdlib", "predicates", "serde", @@ -2610,7 +2539,7 @@ dependencies = [ name = "neopdf_gui" version = "0.3.2" dependencies = [ - "neopdf 0.3.2", + "neopdf", "rayon", "serde", "serde_json", @@ -2620,12 +2549,34 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "neopdf_legacy" +version = "0.3.2" +dependencies = [ + "bincode", + "flate2", + "git-version", + "indicatif", + "itertools 0.13.0", + "lz4_flex", + "ndarray", + "ninterp", + "rayon", + "regex", + "serde", + "serde_yaml", + "tar", + "tempfile", + "thiserror 1.0.69", + "ureq", +] + [[package]] name = "neopdf_pyapi" version = "0.3.2" dependencies = [ "ndarray", - "neopdf 0.3.2", + "neopdf", "numpy", "pyo3", "thiserror 1.0.69", @@ -2646,7 +2597,7 @@ name = "neopdf_wolfram" version = "0.3.2" dependencies = [ "lazy_static", - "neopdf 0.3.2", + "neopdf", "parking_lot", "wolfram-library-link", ] @@ -3501,61 +3452,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "quinn" -version = "0.11.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "626214629cda6781b6dc1d316ba307189c85ba657213ce642d9c77670f8202c8" -dependencies = [ - "bytes", - "cfg_aliases", - "pin-project-lite", - "quinn-proto", - "quinn-udp", - "rustc-hash", - "rustls", - "socket2", - "thiserror 2.0.12", - "tokio", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-proto" -version = "0.11.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e" -dependencies = [ - "bytes", - "getrandom 0.3.3", - "lru-slab", - "rand 0.9.2", - "ring", - "rustc-hash", - "rustls", - "rustls-pki-types", - "slab", - "thiserror 2.0.12", - "tinyvec", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-udp" -version = "0.5.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcebb1209ee276352ef14ff8732e24cc2b02bbac986cd74a4c81bcb2f9881970" -dependencies = [ - "cfg_aliases", - "libc", - "once_cell", - "socket2", - "tracing", - "windows-sys 0.59.0", -] - [[package]] name = "quote" version = "1.0.40" @@ -3596,16 +3492,6 @@ dependencies = [ "rand_core 0.6.4", ] -[[package]] -name = "rand" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" -dependencies = [ - "rand_chacha 0.9.0", - "rand_core 0.9.3", -] - [[package]] name = "rand_chacha" version = "0.2.2" @@ -3626,16 +3512,6 @@ dependencies = [ "rand_core 0.6.4", ] -[[package]] -name = "rand_chacha" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" -dependencies = [ - "ppv-lite86", - "rand_core 0.9.3", -] - [[package]] name = "rand_core" version = "0.5.1" @@ -3654,15 +3530,6 @@ dependencies = [ "getrandom 0.2.16", ] -[[package]] -name = "rand_core" -version = "0.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" -dependencies = [ - "getrandom 0.3.3", -] - [[package]] name = "rand_hc" version = "0.2.0" @@ -3788,31 +3655,24 @@ version = "0.12.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cbc931937e6ca3a06e3b6c0aa7841849b160a90351d6ab467a8b9b9959767531" dependencies = [ - "async-compression", "base64 0.22.1", "bytes", - 
"futures-channel", "futures-core", "futures-util", "http", "http-body", "http-body-util", "hyper", - "hyper-rustls", "hyper-util", "js-sys", "log", "percent-encoding", "pin-project-lite", - "quinn", - "rustls", - "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", "tokio", - "tokio-rustls", "tokio-util", "tower", "tower-http", @@ -3822,7 +3682,6 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", ] [[package]] @@ -3916,6 +3775,7 @@ version = "0.23.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2491382039b29b9b11ff08b76ff6c97cf287671dbb74f0be44bda389fffe9bd1" dependencies = [ + "log", "once_cell", "ring", "rustls-pki-types", @@ -3930,7 +3790,6 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" dependencies = [ - "web-time", "zeroize", ] @@ -4958,21 +4817,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "tinyvec" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" -dependencies = [ - "tinyvec_macros", -] - -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - [[package]] name = "tokio" version = "1.46.1" @@ -4990,16 +4834,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "tokio-rustls" -version = "0.26.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" -dependencies = [ - "rustls", - "tokio", -] - [[package]] name = "tokio-util" version = "0.7.15" @@ -5323,6 +5157,22 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "url", + "webpki-roots 0.26.11", +] + [[package]] name = "url" version = "2.5.4" @@ -5606,6 +5456,15 @@ dependencies = [ "system-deps", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.2", +] + [[package]] name = "webpki-roots" version = "1.0.2" diff --git a/Cargo.toml b/Cargo.toml index 455f31d..b099fc4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["neopdf", "neopdf_capi", "neopdf_cli", "neopdf_gui", "neopdf_pyapi", "neopdf_tmdlib", "neopdf_wolfram"] +members = ["neopdf", "neopdf_capi", "neopdf_cli", "neopdf_gui", "neopdf_legacy", "neopdf_pyapi", "neopdf_tmdlib", "neopdf_wolfram"] default-members = ["neopdf", "neopdf_capi", "neopdf_cli", "neopdf_pyapi"] resolver = "2" @@ -24,7 +24,7 @@ ndarray = { version = "0.16.1", features = ["serde"] } ninterp = "0.7.3" rayon = "1.5" regex = "1.11.1" -reqwest = { version = "0.12.22", features = ["blocking", "gzip", "rustls-tls", "stream"], default-features = false } +ureq = { version = "2", features = ["tls"] } serde = { version = "1.0", features = ["derive"] } serde_yaml = "0.9" tar = "0.4.44" @@ -50,6 +50,7 @@ predicates = "3.1.3" # Internal crates neopdf = { path = "./neopdf", version = "0.3.2" } +neopdf_legacy = { path = "./neopdf_legacy", version = "0.3.2" } [workspace.lints.clippy] all = { level = "warn", priority = -1 } diff --git a/neopdf/Cargo.toml b/neopdf/Cargo.toml index b54b9e2..8e3d2d5 100644 --- a/neopdf/Cargo.toml +++ 
b/neopdf/Cargo.toml @@ -23,7 +23,7 @@ thiserror.workspace = true lz4_flex.workspace = true bincode.workspace = true flate2.workspace = true -reqwest.workspace = true +ureq.workspace = true tar.workspace = true itertools.workspace = true regex.workspace = true @@ -31,7 +31,7 @@ git-version.workspace = true indicatif.workspace = true # For backward compatibility with v0.2.0 format -neopdf_legacy = { package = "neopdf", version = "0.2.0" } +neopdf_legacy.workspace = true [dev-dependencies] criterion.workspace = true diff --git a/neopdf/src/manage.rs b/neopdf/src/manage.rs index 0504029..8ace9ce 100644 --- a/neopdf/src/manage.rs +++ b/neopdf/src/manage.rs @@ -84,25 +84,10 @@ impl ManageData { ); println!("Downloading PDF set from: {}", url); - let response = reqwest::blocking::Client::builder() - .timeout(None) - .build()? - .get(&url) - .send()?; - - if !response.status().is_success() { - return Err(format!( - "Failed to download PDF set '{}': HTTP {}", - self.set_name, - response.status() - ) - .into()); - } + let response = ureq::get(&url).call()?; let total_size = response - .headers() - .get(reqwest::header::CONTENT_LENGTH) - .and_then(|v| v.to_str().ok()) + .header("content-length") .and_then(|s| s.parse::().ok()) .unwrap_or(0); @@ -112,7 +97,7 @@ impl ManageData { .progress_chars("=>-")); let mut response_bytes = Vec::new(); - let mut decorated_response = pb.wrap_read(response); + let mut decorated_response = pb.wrap_read(response.into_reader()); decorated_response.read_to_end(&mut response_bytes)?; let tar = GzDecoder::new(&response_bytes[..]); diff --git a/neopdf/src/utils.rs b/neopdf/src/utils.rs index 461834e..67d1d17 100644 --- a/neopdf/src/utils.rs +++ b/neopdf/src/utils.rs @@ -14,9 +14,10 @@ const LHAPDF_INDEX_URL: &str = "https://lhapdfsets.web.cern.ch/current/pdfsets.i /// For example, if `NNPDF40_nnlo_as_01180` has base ID 331100, then LHAID 331103 /// maps to `("NNPDF40_nnlo_as_01180", 3)`. 
pub(crate) fn lookup_lhaid(lhaid: u32) -> Result<(String, usize), String> { - let text = reqwest::blocking::get(LHAPDF_INDEX_URL) + let text = ureq::get(LHAPDF_INDEX_URL) + .call() .map_err(|e| format!("Failed to fetch pdfsets.index: {e}"))? - .text() + .into_string() .map_err(|e| format!("Failed to read pdfsets.index response: {e}"))?; let mut entries: Vec<(u32, String)> = Vec::new(); diff --git a/neopdf_legacy/Cargo.toml b/neopdf_legacy/Cargo.toml new file mode 100644 index 0000000..6f7921e --- /dev/null +++ b/neopdf_legacy/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "neopdf_legacy" +description = "neopdf v0.2.0 for backward-compatible file reading" + +categories.workspace = true +edition.workspace = true +keywords.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true +version.workspace = true + +[dependencies] +bincode.workspace = true +flate2.workspace = true +git-version.workspace = true +indicatif.workspace = true +itertools.workspace = true +lz4_flex.workspace = true +ndarray.workspace = true +ninterp.workspace = true +rayon.workspace = true +regex.workspace = true +serde.workspace = true +serde_yaml.workspace = true +tar.workspace = true +tempfile.workspace = true +thiserror.workspace = true +ureq.workspace = true diff --git a/neopdf_legacy/src/alphas.rs b/neopdf_legacy/src/alphas.rs new file mode 100644 index 0000000..1a4fd30 --- /dev/null +++ b/neopdf_legacy/src/alphas.rs @@ -0,0 +1,386 @@ +//! This module provides implementations for calculating the strong coupling constant. +//! +//! It includes support for different calculation methods, such as analytic formulas and +//! interpolation from tabulated values, mirroring the functionality available in `LHAPDF`. 
+
+use ninterp::interpolator::Extrapolate;
+use ninterp::prelude::*;
+use std::collections::HashMap;
+use thiserror::Error;
+
+use super::metadata::MetaData;
+use super::strategy::AlphaSCubicInterpolation;
+
+/// Errors that can occur during the analytical computations of `alpha_s`.
+#[derive(Debug, Error)]
+pub enum Error {
+    /// Error indicating that no `Lambda_QCD` value is defined for the given `nf`.
+    #[error("No subgrid LambdaQCD for nf={nf}")]
+    LambdaQCDValueNotFound {
+        /// The number of active flavors.
+        nf: u32,
+    },
+    /// Error indicating that zero active flavor is not accepted.
+    #[error("Invalid zero nf value")]
+    NfZeroValueError,
+    /// Error indicating that the order to compute the Beta function is not supported.
+    #[error("Invalid order value to compute Beta o={order}")]
+    BetaOrderValueError {
+        /// Order to compute the Beta function.
+        order: u32,
+    },
+}
+
+/// Enum representing the different methods for alpha_s calculation.
+pub enum AlphaS {
+    /// `alpha_s` evaluated from the analytic running formula.
+    Analytic(AlphaSAnalytic),
+    /// `alpha_s` interpolated from the values tabulated in the metadata.
+    Interpol(AlphaSInterpol),
+}
+
+impl AlphaS {
+    /// Creates a new `AlphaS` calculator from PDF metadata.
+    ///
+    /// The interpolating calculator is selected whenever the metadata carries tabulated
+    /// `alpha_s` values (`alphas_vals` is non-empty); otherwise the analytic one is used.
+    ///
+    /// # Errors
+    ///
+    /// Propagates the error message of the selected calculator's constructor.
+    pub fn from_metadata(meta: &MetaData) -> Result<Self, String> {
+        // TODO: Use `meta.alphas_type` for the logics.
+        if meta.alphas_vals.is_empty() {
+            Ok(AlphaS::Analytic(AlphaSAnalytic::from_metadata(meta)?))
+        } else {
+            Ok(AlphaS::Interpol(AlphaSInterpol::from_metadata(meta)?))
+        }
+    }
+
+    /// Calculates the strong coupling `alpha_s` at a given `Q^2`.
+    ///
+    /// The call is forwarded to the calculator wrapped by the active variant.
+    pub fn alphas_q2(&self, q2: f64) -> f64 {
+        match self {
+            AlphaS::Analytic(analytic) => analytic.alphas_q2(q2),
+            AlphaS::Interpol(interpol) => interpol.alphas_q2(q2),
+        }
+    }
+}
+
+/// Strong coupling calculator using the analytic formulas.
+pub struct AlphaSAnalytic {
+    /// Perturbative order used by `alphas_q2`.
+    qcd_order: u32,
+    /// Flavor scheme name; compared case-insensitively against `"FIXED"`.
+    fl_scheme: String,
+    /// `Lambda_QCD` values keyed by the number of active flavors.
+    lambda_maps: HashMap<u32, f64>,
+    /// Charm mass squared.
+    mc_sq: f64,
+    /// Bottom mass squared.
+    mb_sq: f64,
+    /// Top mass squared.
+    mt_sq: f64,
+    /// Number of flavors used when the flavor scheme is `"FIXED"`.
+    num_fl: u32,
+}
+
+impl AlphaSAnalytic {
+    /// Builds the analytic calculator from PDF metadata.
+    ///
+    /// The order is taken from `alphas_order_qcd`, falling back to `order_qcd` when the
+    /// former is zero. Quark masses are stored squared so they can be compared to `Q^2`.
+    ///
+    /// # Errors
+    ///
+    /// The current implementation never fails; the `Result` mirrors the signature of
+    /// `AlphaSInterpol::from_metadata`.
+    pub fn from_metadata(meta: &MetaData) -> Result<Self, String> {
+        let mut lambda_maps = HashMap::new();
+        // TODO: decide what to do about these hardcoded values.
+        lambda_maps.insert(3, 0.339);
+        lambda_maps.insert(4, 0.296);
+        lambda_maps.insert(5, 0.213);
+
+        let alphas_order_qcd = if meta.alphas_order_qcd == 0 {
+            meta.order_qcd
+        } else {
+            meta.alphas_order_qcd
+        };
+
+        Ok(Self {
+            qcd_order: alphas_order_qcd,
+            lambda_maps,
+            mc_sq: meta.m_charm * meta.m_charm,
+            mb_sq: meta.m_bottom * meta.m_bottom,
+            mt_sq: meta.m_top * meta.m_top,
+            num_fl: meta.number_flavors,
+            fl_scheme: meta.flavor_scheme.clone(),
+        })
+    }
+
+    /// Returns `true` when the flavor scheme is `"FIXED"` (case-insensitive).
+    fn is_fixed_flavor_scheme(&self) -> bool {
+        self.fl_scheme.to_uppercase() == "FIXED"
+    }
+
+    /// Returns the number of active flavors at `q2`.
+    ///
+    /// In the fixed scheme this is always `num_fl`. Otherwise the heaviest quark whose
+    /// squared mass is positive and strictly below `q2` decides, with 3 as the floor.
+    fn number_flavors_q2(&self, q2: f64) -> u32 {
+        if self.is_fixed_flavor_scheme() {
+            self.num_fl
+        } else if q2 > self.mt_sq && self.mt_sq > 0.0 {
+            6
+        } else if q2 > self.mb_sq && self.mb_sq > 0.0 {
+            5
+        } else if q2 > self.mc_sq && self.mc_sq > 0.0 {
+            4
+        } else {
+            3
+        }
+    }
+
+    /// Looks up `Lambda_QCD` for `nf` active flavors.
+    ///
+    /// In the fixed scheme `nf` is ignored and the value for `num_fl` is returned. In any
+    /// other scheme the lookup falls back to `nf - 1`, `nf - 2`, ... until a value is found.
+    ///
+    /// # Errors
+    ///
+    /// * `Error::LambdaQCDValueNotFound` in the fixed scheme when `num_fl` has no entry.
+    /// * `Error::NfZeroValueError` otherwise, when `nf` is zero or the fallback reaches zero
+    ///   without finding an entry.
+    fn lambda_qcd(&self, nf: u32) -> Result<f64, Error> {
+        // NOTE: This is better be checked using `alphas_type`.
+        if self.is_fixed_flavor_scheme() {
+            return self
+                .lambda_maps
+                .get(&self.num_fl)
+                .copied()
+                .ok_or(Error::LambdaQCDValueNotFound { nf: self.num_fl });
+        }
+
+        // Try `nf`, `nf - 1`, ..., 1 in that order; an empty range (nf == 0) is an error.
+        (1..=nf)
+            .rev()
+            .find_map(|n| self.lambda_maps.get(&n).copied())
+            .ok_or(Error::NfZeroValueError)
+    }
+
+    /// Returns the Beta-function coefficient of order `bto` for `nf` active flavors.
+    ///
+    /// # Errors
+    ///
+    /// Returns `Error::BetaOrderValueError` when `bto` is larger than 4.
+    fn betas(&self, bto: u32, nf: u32) -> Result<f64, Error> {
+        // Copied from https://gitlab.com/hepcedar/lhapdf/-/blob/main/src/AlphaS.cc
+        let nf = f64::from(nf);
+        let (nf2, nf3, nf4) = (nf * nf, nf * nf * nf, nf * nf * nf * nf);
+        match bto {
+            0 => Ok(0.875352187 - 0.053051647 * nf),
+            1 => Ok(0.6459225457 - 0.0802126037 * nf),
+            2 => Ok(0.719864327 - 0.140904490 * nf + 0.00303291339 * nf2),
+            3 => Ok(1.172686 - 0.2785458 * nf + 0.01624467 * nf2 + 0.0000601247 * nf3),
+            4 => Ok(1.714138 - 0.5940794 * nf + 0.05607482 * nf2
+                - 0.0007380571 * nf3
+                - 0.00000587968 * nf4),
+            _ => Err(Error::BetaOrderValueError { order: bto }),
+        }
+    }
+
+    /// Calculates alpha_s(Q2) using the analytic running formula.
+    ///
+    /// Returns `f64::INFINITY` when `q2` does not exceed `Lambda_QCD^2`, and the fixed
+    /// reference value `0.118` when the order is zero. The `Lambda_QCD` lookup and the
+    /// infinity check happen before the order-zero shortcut.
+    ///
+    /// # Panics
+    ///
+    /// Panics when no `Lambda_QCD` value is available, e.g. in the fixed flavor scheme
+    /// with a number of flavors that has no entry.
+    pub fn alphas_q2(&self, q2: f64) -> f64 {
+        // Copied from https://gitlab.com/hepcedar/lhapdf/-/blob/main/src/AlphaS_Analytic.cc
+        const ALPHAS_MZ_REFERENCE: f64 = 0.118; // _alpha_mz reference value
+
+        let nf = self.number_flavors_q2(q2);
+        let lambda_qcd = self
+            .lambda_qcd(nf)
+            .expect("no LambdaQCD value available for the active number of flavors");
+
+        if q2 <= lambda_qcd * lambda_qcd {
+            return f64::INFINITY;
+        }
+
+        let lnx = (q2 / (lambda_qcd * lambda_qcd)).ln();
+        let (lnlnx, lnlnx2, lnlnx3) = {
+            let lnlnx = lnx.ln();
+            (lnlnx, lnlnx * lnlnx, lnlnx * lnlnx * lnlnx)
+        };
+        let y = 1.0 / lnx;
+
+        // `betas` only fails for orders above 4, so the lookups below cannot fail.
+        let beta = |order: u32| {
+            self.betas(order, nf)
+                .expect("beta coefficients are defined for orders 0..=4")
+        };
+
+        let beta0 = beta(0);
+        let beta1 = beta(1);
+        let (beta02, beta12) = (beta0 * beta0, beta1 * beta1);
+        let prefac = 1.0 / beta0;
+        let mut tmp = 1.0;
+
+        if self.qcd_order == 0 {
+            return ALPHAS_MZ_REFERENCE;
+        }
+
+        if self.qcd_order > 1 {
+            let a_1 = beta1 * lnlnx / beta02;
+            tmp -= a_1 * y;
+        }
+
+        if self.qcd_order > 2 {
+            let beta2 = beta(2);
+
+            let prefac_b = beta12 / (beta02 * beta02);
+            let a_20 = lnlnx2 - lnlnx;
+            let a_21 = beta2 * beta0 / beta12;
+            let a_22 = 1.0;
+            tmp += prefac_b * y * y * (a_20 + a_21 - a_22);
+        }
+
+        if self.qcd_order > 3 {
+            let beta2 = beta(2);
+            let beta3 = beta(3);
+
+            let prefac_c = 1. / (beta02 * beta02 * beta02);
+            let a_30 = (beta12 * beta1) * (lnlnx3 - (5.0 / 2.0) * lnlnx2 - 2.0 * lnlnx + 0.5);
+            let a_31 = 3.0 * beta0 * beta1 * beta2 * lnlnx;
+            let a_32 = 0.5 * beta02 * beta3;
+            tmp -= prefac_c * y * y * y * (a_30 + a_31 - a_32);
+        }
+
+        prefac * y * tmp
+    }
+}
+
+/// Strong coupling calculator using interpolation.
+pub struct AlphaSInterpol {
+    /// One-dimensional interpolator of `alpha_s` over `ln(Q^2)`.
+    interpolator: Interp1DOwned<f64, AlphaSCubicInterpolation>,
+}
+
+impl AlphaSInterpol {
+    /// Builds the interpolating calculator from the `alpha_s` table stored in the metadata.
+    ///
+    /// Knots whose `Q` value equals the preceding entry of `alphas_q_values` are dropped,
+    /// and the remaining `Q` values are converted to `ln(Q^2)` before interpolation.
+    ///
+    /// # Errors
+    ///
+    /// Returns the stringified interpolator error when the interpolator cannot be built.
+    pub fn from_metadata(meta: &MetaData) -> Result<Self, String> {
+        let (q_values, alphas_vals): (Vec<_>, Vec<_>) = meta
+            .alphas_q_values
+            .iter()
+            .zip(&meta.alphas_vals)
+            .enumerate()
+            .filter(|(i, (&q, _))| *i == 0 || q != meta.alphas_q_values[i - 1])
+            .map(|(_, (&q, &alpha))| (q, alpha))
+            .unzip();
+
+        // The interpolation variable is ln(Q^2), not Q^2 itself.
+        let log_q2_values: Vec<f64> = q_values.iter().map(|&q| (q * q).ln()).collect();
+
+        let interpolator = Interp1D::new(
+            log_q2_values.into(),
+            alphas_vals.into(),
+            AlphaSCubicInterpolation,
+            Extrapolate::Error,
+        )
+        .map_err(|e| e.to_string())?;
+
+        Ok(Self { interpolator })
+    }
+
+    /// Interpolates `alpha_s` at the given `Q^2`.
+    ///
+    /// Returns `0.0` when the interpolation fails; the interpolator is configured with
+    /// `Extrapolate::Error`, so no error is surfaced to the caller in that case.
+    pub fn alphas_q2(&self, q2: f64) -> f64 {
+        self.interpolator.interpolate(&[q2.ln()]).unwrap_or(0.0)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::metadata::{InterpolatorType, MetaData, MetaDataV1, SetType};
+
+    fn mock_metadata() -> MetaData {
+        MetaData::new_v1(MetaDataV1 {
+            set_desc: "Test PDF".to_string(),
+            set_index: 0,
+            num_members: 1,
+            x_min: 1e-9,
+            x_max: 1.0,
+            q_min: 1.0,
+            q_max: 1000.0,
+            flavors: vec![],
+            format: "test".to_string(),
+            alphas_q_values: vec![],
+            alphas_vals: vec![],
+            polarised: false,
+            set_type: SetType::SpaceLike,
+            interpolator_type: InterpolatorType::LogBicubic,
+            error_type: "test".to_string(),
+            hadron_pid: 2212,
+            git_version: "test".to_string(),
+            code_version: "test".to_string(),
+            flavor_scheme: "variable".to_string(),
+            order_qcd: 2,
+            alphas_order_qcd: 2,
+            m_w: 80.385,
+            m_z: 91.1876,
+            m_up: 0.0,
+            m_down: 0.0,
+            m_strange: 0.0,
+            m_charm: 1.4,
+            m_bottom: 4.75,
+            m_top: 173.0,
+            alphas_type: "analytic".to_string(),
+            number_flavors: 5,
+        })
+    }
+
+    #[test]
+    fn test_alphas_analytic_order_zero() {
+        let mut meta = mock_metadata();
+        meta.alphas_order_qcd = 0;
+        meta.order_qcd = 0;
+        let analytic = AlphaSAnalytic::from_metadata(&meta).unwrap();
+        assert_eq!(analytic.alphas_q2(100.0), 0.118);
+    }
+
+    #[test]
+    fn test_number_flavors_q2() {
+        let meta = mock_metadata();
+        let analytic = AlphaSAnalytic::from_metadata(&meta).unwrap();
+
+        assert_eq!(analytic.number_flavors_q2(1.0), 3);
+        assert_eq!(analytic.number_flavors_q2(2.0), 4);
+        assert_eq!(analytic.number_flavors_q2(25.0), 5);
+        assert_eq!(analytic.number_flavors_q2(30000.0), 6);
+
+        let mut meta_fixed = mock_metadata();
+        meta_fixed.flavor_scheme = "FIXED".to_string();
+        meta_fixed.number_flavors = 4;
+        let analytic_fixed = AlphaSAnalytic::from_metadata(&meta_fixed).unwrap();
+        assert_eq!(analytic_fixed.number_flavors_q2(1.0), 4);
+        assert_eq!(analytic_fixed.number_flavors_q2(30000.0), 4);
+    }
+
+    #[test]
+    fn test_lambda_qcd() {
+        let meta = mock_metadata();
+        let analytic = AlphaSAnalytic::from_metadata(&meta).unwrap();
+
+        assert_eq!(analytic.lambda_qcd(3).unwrap(), 0.339);
+        assert_eq!(analytic.lambda_qcd(4).unwrap(), 0.296);
+        assert_eq!(analytic.lambda_qcd(5).unwrap(), 0.213);
+        assert_eq!(analytic.lambda_qcd(6).unwrap(), 0.213); // falls back to nf-1
+        assert!(analytic.lambda_qcd(0).is_err());
+    }
+
+    #[test]
+    fn test_lambda_qcd_fixed_unknown_nf() {
+        let mut meta = mock_metadata();
+        meta.flavor_scheme = "FIXED".to_string();
+        meta.number_flavors = 9;
+        let analytic = AlphaSAnalytic::from_metadata(&meta).unwrap();
+        assert!(analytic.lambda_qcd(5).is_err());
+    }
+
+    #[test]
+    fn test_betas() {
+        let meta = mock_metadata();
+        let analytic = AlphaSAnalytic::from_metadata(&meta).unwrap();
+
+        let b0 = analytic.betas(0, 5).unwrap();
+        assert!((b0 - 0.610093952).abs() < 1e-9);
+        for order in 0..=4u32 {
+            assert!(analytic.betas(order, 5).is_ok());
+        }
+        assert!(analytic.betas(5, 5).is_err());
+    }
+
+    #[test]
+    fn test_alphas_q2_analytic_values() {
+        let meta = mock_metadata();
+        let analytic = AlphaSAnalytic::from_metadata(&meta).unwrap();
+
+        let as_q2 = analytic.alphas_q2(100.0);
+        assert!(as_q2 > 0.0 && as_q2 < 1.0);
+        assert!(analytic.alphas_q2(1000.0) < as_q2);
+    }
+
+    #[test]
+    fn test_alphas_q2_all_orders() {
+        for order in 1..=4u32 {
+            let mut meta = mock_metadata();
+            meta.alphas_order_qcd = order;
+            let analytic = AlphaSAnalytic::from_metadata(&meta).unwrap();
+            assert!(analytic.alphas_q2(100.0) > 0.0);
+        }
+    }
+
+    #[test]
+    fn test_alphas_q2_below_lambda() {
+        let mut meta = mock_metadata();
+        meta.alphas_order_qcd = 1;
+        let analytic = AlphaSAnalytic::from_metadata(&meta).unwrap();
+        assert!(analytic.alphas_q2(1e-10).is_infinite());
+    }
+
+    #[test]
+    fn test_alphas_order_defaults_to_order_qcd() {
+        let mut meta = mock_metadata();
+        meta.alphas_order_qcd = 0;
+        meta.order_qcd = 3;
+        let analytic = AlphaSAnalytic::from_metadata(&meta).unwrap();
+        assert!(analytic.alphas_q2(100.0) > 0.0);
+    }
+
+    #[test]
+    fn test_alphas_from_metadata_analytic() {
+        let meta = mock_metadata();
+        let alphas = AlphaS::from_metadata(&meta).unwrap();
+        assert!(alphas.alphas_q2(100.0) > 0.0);
+    }
+
+    #[test]
+    fn test_alphas_from_metadata_interpol() {
+        let mut meta = mock_metadata();
+        meta.alphas_q_values = vec![1.0, 10.0, 100.0];
+        meta.alphas_vals = vec![0.5, 0.3, 0.118];
+        let alphas = AlphaS::from_metadata(&meta).unwrap();
+        assert!(alphas.alphas_q2(100.0) > 0.0);
+    }
+}
diff --git a/neopdf_legacy/src/converter.rs b/neopdf_legacy/src/converter.rs
new file mode 100644
index 0000000..824bc22
--- /dev/null
+++ b/neopdf_legacy/src/converter.rs
@@ -0,0 +1,308 @@
+//! This module provides utilities for converting LHAPDF sets to the NeoPDF format and for
+//! combining multiple nuclear PDF sets into a single NeoPDF file.
+//!
+//! Main functions:
+//! - `convert_lhapdf`: Converts an LHAPDF set to NeoPDF format and writes it to disk.
+//! - `combine_lhapdf_npdfs`: Combines several nuclear PDF sets (with different nucleon
+//!   numbers) into a single NeoPDF file with explicit A dependence.
+use ndarray::{concatenate, Array1, Axis};
+use regex::Regex;
+
+use super::gridpdf::GridArray;
+use super::metadata::{InterpolatorType, MetaData};
+use super::parser::LhapdfSet;
+use super::subgrid::{ParamRange, SubGrid};
+use super::writer::GridArrayCollection;
+
+/// Converts an LHAPDF set to the NeoPDF format and writes it to disk.
+///
+/// The metadata of the first member is used for the whole set.
+///
+/// # Arguments
+///
+/// * `pdf_name` - The name of the LHAPDF set (e.g., "NNPDF40_nnlo_as_01180").
+/// * `output_path` - The path to the output NeoPDF file.
+///
+/// # Errors
+///
+/// Returns an error if the set has no members or if writing the output fails.
+pub fn convert_lhapdf<P: AsRef<std::path::Path>>(
+    pdf_name: &str,
+    output_path: P,
+) -> Result<(), Box<dyn std::error::Error>> {
+    let lhapdf_set = LhapdfSet::new(pdf_name);
+    let members = lhapdf_set.members();
+    if members.is_empty() {
+        return Err("No members found in the LHAPDF set".into());
+    }
+
+    // All members share the same metadata; borrow it instead of cloning.
+    let metadata = &members[0].0;
+    let grids: Vec<&GridArray> = members
+        .iter()
+        .map(|(_meta, knot_array)| knot_array)
+        .collect();
+
+    GridArrayCollection::compress(&grids, metadata, output_path)?;
+    Ok(())
+}
+
+/// Combines a list of nuclear PDF sets (differing in nucleon number A) into a single NeoPDF
+/// file with explicit A dependence.
+///
+/// # Arguments
+/// * `pdf_names` - List of PDF set names (each with a different A).
+/// * `output_path` - Output NeoPDF file path.
+///
+/// # Errors
+/// Returns an error if loading or writing fails, or if the sets are not compatible.
+pub fn combine_lhapdf_npdfs<P: AsRef<std::path::Path>>(
+    pdf_names: &[&str],
+    output_path: P,
+) -> Result<(), Box<dyn std::error::Error>> {
+    if pdf_names.is_empty() {
+        return Err("No PDF set names provided".into());
+    }
+
+    // Regexes to extract A from the PDF set name
+    let re_nnpdf = Regex::new(r"_A(\d+)").unwrap();
+    let re_ncteq = Regex::new(r"_(\d+)_(\d+)$").unwrap();
+    let re_epps = Regex::new(r"[a-zA-Z]+(\d+)$").unwrap();
+    let mut a_values = Vec::with_capacity(pdf_names.len());
+    let mut all_members: Vec<Vec<(MetaData, GridArray)>> = Vec::with_capacity(pdf_names.len());
+
+    for &pdf_name in pdf_names {
+        // The first matching pattern wins; the captured group is digits only, so parsing
+        // it as a float cannot fail.
+        let a = if let Some(cap) = re_nnpdf.captures(pdf_name) {
+            cap[1].parse::<f64>().unwrap()
+        } else if let Some(cap) = re_ncteq.captures(pdf_name) {
+            cap[1].parse::<f64>().unwrap()
+        } else if let Some(cap) = re_epps.captures(pdf_name) {
+            cap[1].parse::<f64>().unwrap()
+        } else if pdf_name.ends_with("_p") {
+            1.0 // proton
+        } else {
+            return Err(format!("Could not extract A from PDF name: {}", pdf_name).into());
+        };
+        a_values.push(a);
+        let set = LhapdfSet::new(pdf_name);
+        let members = set.members();
+        if members.is_empty() {
+            return Err(format!("No members found in set: {}", pdf_name).into());
+        }
+        all_members.push(members);
+    }
+
+    // `a_values` is non-empty here and holds only parsed finite digits or `1.0`, so both
+    // the comparison and the min/max lookups succeed.
+    let nucleons_range = ParamRange::new(
+        *a_values
+            .iter()
+            .min_by(|a, b| a.partial_cmp(b).unwrap())
+            .unwrap(),
+        *a_values
+            .iter()
+            .max_by(|a, b| a.partial_cmp(b).unwrap())
+            .unwrap(),
+    );
+
+    // Check all sets have the same number of members
+    let num_members = all_members[0].len();
+    if !all_members.iter().all(|v| v.len() == num_members) {
+        return Err("All sets must have the same number of members".into());
+    }
+
+    // For each member index, combine the corresponding member from each set along the A dimension
+    let mut combined_grids = Vec::with_capacity(num_members);
+    let mut meta = all_members[0][0].0.clone();
+    meta.set_desc = format!("Combined nuclear PDFs: {}", pdf_names.join(", "));
+    meta.num_members = num_members as u32;
+    meta.interpolator_type = InterpolatorType::LogTricubic;
+
+    for member_idx in 0..num_members {
+        // For each set, get the GridArray for this member
+        let member_arrays: Vec<&GridArray> = all_members.iter().map(|v| &v[member_idx].1).collect();
+
+        let pids = member_arrays[0].pids.clone();
+        let num_subgrids = member_arrays[0].subgrids.len();
+        if !member_arrays
+            .iter()
+            .all(|ga| ga.pids == pids && ga.subgrids.len() == num_subgrids)
+        {
+            return Err("All sets must have the same flavors and subgrid structure".into());
+        }
+
+        // For each subgrid, stack along the A dimension
+        let mut combined_subgrids = Vec::with_capacity(num_subgrids);
+        for subgrid_idx in 0..num_subgrids {
+            // For each set, get the subgrid
+            let subgrids: Vec<&SubGrid> = member_arrays
+                .iter()
+                .map(|ga| &ga.subgrids[subgrid_idx])
+                .collect();
+
+            // Check x, q2, kT and alphas grids match
+            let xs = &subgrids[0].xs;
+            let q2s = &subgrids[0].q2s;
+            let kts = &subgrids[0].kts;
+            let alphas = &subgrids[0].alphas;
+            if !subgrids
+                .iter()
+                .all(|sg| sg.xs == *xs && sg.q2s == *q2s && sg.alphas == *alphas && sg.kts == *kts)
+            {
+                return Err("All sets must have the same x, q2, kT, and alphas grids".into());
+            }
+
+            // Concatenate along the nucleons axis to get [nucleons=pdf_names.len(), ...]
+            let grid_views: Vec<_> = subgrids.iter().map(|sg| sg.grid.view()).collect();
+            let concatenated = concatenate(Axis(0), &grid_views)?;
+            let nucleons = Array1::from(a_values.clone());
+            let new_subgrid = SubGrid {
+                xs: xs.clone(),
+                q2s: q2s.clone(),
+                kts: kts.clone(),
+                grid: concatenated,
+                nucleons,
+                alphas: alphas.clone(),
+                nucleons_range,
+                alphas_range: subgrids[0].alphas_range,
+                kt_range: subgrids[0].kt_range,
+                x_range: subgrids[0].x_range,
+                q2_range: subgrids[0].q2_range,
+            };
+            combined_subgrids.push(new_subgrid);
+        }
+        let combined_grid = GridArray {
+            pids: pids.clone(),
+            subgrids: combined_subgrids,
+        };
+        combined_grids.push(combined_grid);
+    }
+
+    let combined_grids: Vec<&GridArray> = combined_grids.iter().collect();
+    GridArrayCollection::compress(&combined_grids, &meta, output_path)?;
+    Ok(())
+}
+
+/// Combines a list of PDF sets (differing in alpha_s) into a single NeoPDF file with explicit
+/// `alpha_s` dependence.
+///
+/// # Arguments
+/// * `pdf_names` - List of PDF set names (each with a different alpha_s).
+/// * `output_path` - Output NeoPDF file path.
+///
+/// # Errors
+/// Returns an error if loading or writing fails, or if the sets are not compatible.
+pub fn combine_lhapdf_alphas<P: AsRef<std::path::Path>>(
+    pdf_names: &[&str],
+    output_path: P,
+) -> Result<(), Box<dyn std::error::Error>> {
+    // Turns the digits captured from a set name into an alpha_s value by moving the decimal
+    // point `fractional_digits` places to the left.
+    fn alphas_from_digits(digits: &str, fractional_digits: usize) -> Option<f64> {
+        let value: f64 = digits.parse().ok()?;
+        let scale = (0..fractional_digits).fold(1.0_f64, |scale, _| scale * 10.0);
+        Some(value / scale)
+    }
+
+    if pdf_names.is_empty() {
+        return Err("No PDF set names provided".into());
+    }
+
+    // Regexes to extract alpha_s from the PDF set name
+    let re_nnpdf_ct = Regex::new(r"_as_(\d+)_?").unwrap();
+    let re_msht = Regex::new(r"_as(\d+)").unwrap();
+
+    // One `(alpha_s, members)` entry per set, in the order given by the caller.
+    let mut sets = Vec::with_capacity(pdf_names.len());
+
+    for &pdf_name in pdf_names {
+        let parsed = if let Some(cap) = re_nnpdf_ct.captures(pdf_name) {
+            // `_as_0118` and `_as_01180` both denote 0.118: the first digit is the integer
+            // part, so the number of fractional digits depends on the length of the match.
+            alphas_from_digits(&cap[1], cap[1].len() - 1)
+        } else if let Some(cap) = re_msht.captures(pdf_name) {
+            // `_as118` denotes 0.118: every captured digit is fractional.
+            alphas_from_digits(&cap[1], cap[1].len())
+        } else {
+            None
+        };
+        let alphas = parsed
+            .ok_or_else(|| format!("Could not extract alpha_s from PDF name: {}", pdf_name))?;
+
+        let set = LhapdfSet::new(pdf_name);
+        let members = set.members();
+        if members.is_empty() {
+            return Err(format!("No members found in set: {}", pdf_name).into());
+        }
+        sets.push((alphas, members));
+    }
+
+    // Check all sets have the same number of members
+    let num_members = sets[0].1.len();
+    if !sets.iter().all(|(_, members)| members.len() == num_members) {
+        return Err("All sets must have the same number of members".into());
+    }
+
+    // The metadata template comes from the first set named by the caller, so it is taken
+    // before the sets are reordered below.
+    let mut meta = sets[0].1[0].0.clone();
+    meta.set_desc = format!("Combined alpha_s PDFs: {}", pdf_names.join(", "));
+    meta.num_members = u32::try_from(num_members)?;
+    meta.interpolator_type = InterpolatorType::LogTricubic;
+
+    // Stack the sets by increasing alpha_s, whatever order the names were given in, so that
+    // the alpha_s axis of the combined grid is monotonic.
+    sets.sort_by(|a, b| a.0.total_cmp(&b.0));
+    let alphas_values: Vec<f64> = sets.iter().map(|(alphas, _)| *alphas).collect();
+    let alphas_range = ParamRange::new(
+        alphas_values[0],
+        alphas_values[alphas_values.len() - 1],
+    );
+
+    // For each member index, combine the corresponding member from each set along the alpha_s dimension
+    let mut combined_grids = Vec::with_capacity(num_members);
+
+    for member_idx in 0..num_members {
+        // For each set, get the GridArray for this member
+        let member_arrays: Vec<&GridArray> = sets
+            .iter()
+            .map(|(_, members)| &members[member_idx].1)
+            .collect();
+
+        // Every set must expose the same flavors and the same number of subgrids
+        let pids = member_arrays[0].pids.clone();
+        let num_subgrids = member_arrays[0].subgrids.len();
+        if !member_arrays
+            .iter()
+            .all(|ga| ga.pids == pids && ga.subgrids.len() == num_subgrids)
+        {
+            return Err("All sets must have the same flavors and subgrid structure".into());
+        }
+
+        // For each subgrid, stack along the alpha_s dimension
+        let mut combined_subgrids = Vec::with_capacity(num_subgrids);
+        for subgrid_idx in 0..num_subgrids {
+            // For each set, get the subgrid
+            let subgrids: Vec<&SubGrid> = member_arrays
+                .iter()
+                .map(|ga| &ga.subgrids[subgrid_idx])
+                .collect();
+
+            // Check that the x, q2, kT, and nucleons knots match
+            let xs = &subgrids[0].xs;
+            let q2s = &subgrids[0].q2s;
+            let kts = &subgrids[0].kts;
+            let nucleons = &subgrids[0].nucleons;
+            if !subgrids.iter().all(|sg| {
+                sg.xs == *xs && sg.q2s == *q2s && sg.nucleons == *nucleons && sg.kts == *kts
+            }) {
+                return Err("All sets must have the same x, q2, kT, and nucleons grids".into());
+            }
+
+            // Concatenate along the alphas axis to get [..., alphas=pdf_names.len(), ...]
+            let grid_views: Vec<_> = subgrids.iter().map(|sg| sg.grid.view()).collect();
+            let concatenated = concatenate(Axis(1), &grid_views)?;
+            let alphas = Array1::from(alphas_values.clone());
+            let new_subgrid = SubGrid {
+                xs: xs.clone(),
+                q2s: q2s.clone(),
+                kts: kts.clone(),
+                grid: concatenated,
+                nucleons: nucleons.clone(),
+                alphas,
+                nucleons_range: subgrids[0].nucleons_range,
+                alphas_range,
+                kt_range: subgrids[0].kt_range,
+                x_range: subgrids[0].x_range,
+                q2_range: subgrids[0].q2_range,
+            };
+            combined_subgrids.push(new_subgrid);
+        }
+        let combined_grid = GridArray {
+            pids,
+            subgrids: combined_subgrids,
+        };
+        combined_grids.push(combined_grid);
+    }
+
+    let combined_grids: Vec<&GridArray> = combined_grids.iter().collect();
+    GridArrayCollection::compress(&combined_grids, &meta, output_path)?;
+    Ok(())
+}
diff --git a/neopdf_legacy/src/gridpdf.rs b/neopdf_legacy/src/gridpdf.rs
new file mode 100644
index 0000000..9f27ad1
--- /dev/null
+++ b/neopdf_legacy/src/gridpdf.rs
@@ -0,0 +1,629 @@
+//! This module defines the main PDF grid interface and data structures for handling PDF grid data.
+//!
+//! # Contents
+//!
+//! - [`GridPDF`]: High-level interface for PDF grid interpolation and metadata access.
+//! - [`GridArray`]: Stores the full set of subgrids and flavor IDs.
+
+use core::panic;
+use ndarray::{Array1, Array2};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use thiserror::Error;
+
+use super::alphas::AlphaS;
+use super::interpolator::{DynInterpolator, InterpolatorFactory};
+use super::metadata::{InterpolatorType, MetaData};
+use super::parser::SubgridData;
+use super::subgrid::{ParamRange, RangeParameters, SubGrid};
+
+/// Errors that can occur during PDF grid operations.
+#[derive(Debug, Error)]
+pub enum Error {
+    /// Error indicating that no suitable subgrid was found for the given `x` and `q2` values.
+    #[error("No subgrid found for x={x}, q2={q2}")]
+    SubgridNotFound {
+        /// The momentum fraction `x` value.
+        x: f64,
+        /// The energy scale squared `q2` value.
+        q2: f64,
+    },
+    /// Error indicating invalid interpolation parameters, with a descriptive message.
+    #[error("Invalid interpolation parameters: {0}")]
+    InterpolationError(String),
+}
+
+/// Stores the complete PDF grid data, including all subgrids and flavor information.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct GridArray {
+    /// An array of particle flavor IDs (PIDs).
+    pub pids: Array1<i32>,
+    /// A collection of `SubGrid` instances that make up the full grid.
+    pub subgrids: Vec<SubGrid>,
+}
+
+impl GridArray {
+    /// Creates a new `GridArray` from a vector of `SubgridData`.
+    ///
+    /// Each `SubgridData` entry is turned into one `SubGrid`; the number of flavors passed
+    /// to every `SubGrid` is the length of `pids`.
+    ///
+    /// # Arguments
+    ///
+    /// * `subgrid_data` - A vector of `SubgridData` parsed from the PDF data file.
+    /// * `pids` - A vector of particle flavor IDs.
+    pub fn new(subgrid_data: Vec<SubgridData>, pids: Vec<i32>) -> Self {
+        let nflav = pids.len();
+        let subgrids = subgrid_data
+            .into_iter()
+            .map(|data| {
+                SubGrid::new(
+                    data.nucleons,
+                    data.alphas,
+                    data.kts,
+                    data.xs,
+                    data.q2s,
+                    nflav,
+                    data.grid_data,
+                )
+            })
+            .collect();
+
+        Self {
+            pids: Array1::from_vec(pids),
+            subgrids,
+        }
+    }
+
+    /// Gets the PDF value at a specific knot point in the grid.
+    ///
+    /// # Arguments
+    ///
+    /// * `nucleon_idx` - The index of the nucleon.
+    /// * `alpha_idx` - The index of the alpha_s value.
+    /// * `kt_idx` - The index of the `kT` value.
+    /// * `x_idx` - The index of the `x` value.
+    /// * `q2_idx` - The index of the `q2` value.
+    /// * `flavor_id` - The particle flavor ID.
+    /// * `subgrid_idx` - The index of the subgrid.
+    ///
+    /// # Returns
+    ///
+    /// The PDF value `f64` at the specified grid point.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the `flavor_id` is invalid or if any of the indices is out of bounds.
+    #[allow(clippy::too_many_arguments)]
+    pub fn xf_from_index(
+        &self,
+        nucleon_idx: usize,
+        alpha_idx: usize,
+        kt_idx: usize,
+        x_idx: usize,
+        q2_idx: usize,
+        flavor_id: i32,
+        subgrid_idx: usize,
+    ) -> f64 {
+        let pid_idx = self.pid_index(flavor_id).expect("Invalid flavor ID");
+        self.subgrids[subgrid_idx].grid[[nucleon_idx, alpha_idx, pid_idx, kt_idx, x_idx, q2_idx]]
+    }
+
+    /// Finds the index of the subgrid to use for the given point.
+    ///
+    /// The first subgrid that contains the point wins. When no subgrid contains it, the
+    /// subgrid with the smallest distance to the point is returned instead.
+    ///
+    /// # Arguments
+    ///
+    /// * `points` - A slice of coordinates for the point.
+    ///
+    /// # Returns
+    ///
+    /// The index of the selected subgrid, or `None` when the grid has no subgrids at all.
+    pub fn find_subgrid(&self, points: &[f64]) -> Option<usize> {
+        self.subgrids
+            .iter()
+            .position(|sg| sg.contains_point(points))
+            .or_else(|| {
+                self.subgrids
+                    .iter()
+                    .map(|sg| {
+                        // Each distance is evaluated once. A NaN distance is treated as
+                        // infinitely far away so that it can neither win the comparison nor
+                        // make it panic.
+                        let distance = sg.distance_to_point(points);
+                        if distance.is_nan() {
+                            f64::INFINITY
+                        } else {
+                            distance
+                        }
+                    })
+                    .enumerate()
+                    // `min_by` keeps the first of several equal minima, i.e. the lowest index.
+                    .min_by(|(_, a), (_, b)| a.total_cmp(b))
+                    .map(|(idx, _)| idx)
+            })
+    }
+
+    /// Gets the index corresponding to a given flavor ID.
+    ///
+    /// The PID `0` is treated as an alias of `21`, both for the requested flavor and for
+    /// the stored PIDs. Returns `None` when the flavor is not part of the grid.
+    fn pid_index(&self, flavor_id: i32) -> Option<usize> {
+        let normalize_pid = |pid: i32| if pid == 0 { 21 } else { pid };
+        let wanted = normalize_pid(flavor_id);
+        self.pids.iter().position(|&pid| normalize_pid(pid) == wanted)
+    }
+
+    /// Gets the overall parameter ranges across all subgrids.
+    ///
+    /// This method calculates the minimum and maximum values for the nucleon numbers `A`,
+    /// the AlphaS values `as`, the momentum fraction `x` and the energy scale `q2` across
+    /// all subgrids to determine the global parameter space.
+    ///
+    /// # Returns
+    ///
+    /// A `RangeParameters` struct containing the global parameter ranges.
+    pub fn global_ranges(&self) -> RangeParameters {
+        // Smallest `min` and largest `max` of one parameter over all subgrids. With no
+        // subgrids the fold seeds are returned unchanged, i.e. `(+inf, -inf)`.
+        fn global_range<F>(subgrids: &[SubGrid], extractor: F) -> ParamRange
+        where
+            F: Fn(&SubGrid) -> &ParamRange,
+        {
+            let min = subgrids
+                .iter()
+                .map(|sg| extractor(sg).min)
+                .fold(f64::INFINITY, f64::min);
+            let max = subgrids
+                .iter()
+                .map(|sg| extractor(sg).max)
+                .fold(f64::NEG_INFINITY, f64::max);
+            ParamRange::new(min, max)
+        }
+
+        // Argument order: nucleons, alpha_s, kT, x, q2.
+        RangeParameters::new(
+            global_range(&self.subgrids, |sg| &sg.nucleons_range),
+            global_range(&self.subgrids, |sg| &sg.alphas_range),
+            global_range(&self.subgrids, |sg| &sg.kt_range),
+            global_range(&self.subgrids, |sg| &sg.x_range),
+            global_range(&self.subgrids, |sg| &sg.q2_range),
+        )
+    }
+}
+
+/// Defines the methods for handling negative or small PDF values.
+#[repr(C)]
+#[derive(Debug, Clone)]
+pub enum ForcePositive {
+    /// If the calculated PDF value is negative, it is forced to 0.
+    ClipNegative,
+    /// If the calculated PDF value is less than 1e-10, it is set to 1e-10.
+    ClipSmall,
+    /// No clipping is done, value is returned as it is.
+    NoClipping,
+}
+
+/// The main PDF grid interface, providing high-level methods for interpolation.
+pub struct GridPDF {
+    /// The metadata associated with the PDF set.
+    info: MetaData,
+    /// The underlying grid data stored in a `GridArray`.
+    pub knot_array: GridArray,
+    /// A nested vector of interpolators, indexed first by subgrid and then by flavor.
+    interpolators: Vec<Vec<Box<dyn DynInterpolator>>>,
+    /// Calculator for the running of alpha_s.
+    alphas: AlphaS,
+    /// Clip the values to positive definite numbers if negatives.
+    pub force_positive: Option<ForcePositive>,
+}
+
+impl GridPDF {
+    /// Creates a new `GridPDF` instance.
+    ///
+    /// # Arguments
+    ///
+    /// * `info` - The `MetaData` for the PDF set.
+    /// * `knot_array` - The `GridArray` containing the grid data.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the `AlphaS` calculator cannot be created from `info`.
+    pub fn new(info: MetaData, knot_array: GridArray) -> Self {
+        let interpolators = Self::build_interpolators(&info, &knot_array);
+        let alphas = AlphaS::from_metadata(&info).expect("Failed to create AlphaS calculator");
+
+        Self {
+            info,
+            knot_array,
+            interpolators,
+            alphas,
+            force_positive: None,
+        }
+    }
+
+    /// Sets the method for handling negative or small PDF values.
+    ///
+    /// # Arguments
+    ///
+    /// * `flag` - The `ForcePositive` enum variant specifying the clipping method.
+    pub fn set_force_positive(&mut self, flag: ForcePositive) {
+        self.force_positive = Some(flag);
+    }
+
+    /// Applies the configured clipping method to a given PDF value.
+    ///
+    /// # Arguments
+    ///
+    /// * `value` - The PDF value to which the clipping policy is applied.
+    ///
+    /// # Returns
+    ///
+    /// The clipped PDF value, according to the policy set by `set_force_positive`. When no
+    /// policy has been set the value is returned unchanged.
+    pub fn apply_force_positive(&self, value: f64) -> f64 {
+        // Every case is spelled out (no `_` arm) so that adding a `ForcePositive` variant
+        // is a compile error here instead of silently meaning "no clipping".
+        match &self.force_positive {
+            Some(ForcePositive::ClipNegative) => value.max(0.0),
+            Some(ForcePositive::ClipSmall) => value.max(1e-10),
+            Some(ForcePositive::NoClipping) | None => value,
+        }
+    }
+
+    /// Builds the interpolators for all subgrids and flavors.
+    ///
+    /// The outer vector follows `knot_array.subgrids`, the inner one follows
+    /// `knot_array.pids`.
+    fn build_interpolators(
+        info: &MetaData,
+        knot_array: &GridArray,
+    ) -> Vec<Vec<Box<dyn DynInterpolator>>> {
+        knot_array
+            .subgrids
+            .iter()
+            .map(|subgrid| {
+                (0..knot_array.pids.len())
+                    .map(|pid_idx| {
+                        InterpolatorFactory::create(
+                            info.interpolator_type.clone(),
+                            subgrid,
+                            pid_idx,
+                        )
+                    })
+                    .collect()
+            })
+            .collect()
+    }
+
+    /// Interpolates the PDF value for `(nucleons, alphas, x, q2)` and a given flavor.
+    ///
+    /// # Arguments
+    ///
+    /// * `flavor_id` - The particle flavor ID.
+    /// * `points` - A slice containing the collection of points to interpolate on.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` containing the interpolated PDF value or an `Error`.
+    /// A flavor that is not part of the grid yields `Ok(0.0)`.
+    pub fn xfxq2(&self, flavor_id: i32, points: &[f64]) -> Result<f64, Error> {
+        let subgrid_idx = self.knot_array.find_subgrid(points).ok_or_else(|| {
+            let (x, q2) = self.get_x_q2(points);
+            Error::SubgridNotFound { x, q2 }
+        })?;
+
+        let pid_idx = match self.knot_array.pid_index(flavor_id) {
+            Some(idx) => idx,
+            None => return Ok(0.0),
+        };
+
+        // For the logarithmic interpolator types every coordinate is passed as its
+        // natural logarithm.
+        let use_log = matches!(
+            self.info.interpolator_type,
+            InterpolatorType::LogBilinear
+                | InterpolatorType::LogBicubic
+                | InterpolatorType::LogTricubic
+                | InterpolatorType::LogChebyshev
+        );
+        let coordinates: Vec<f64> = points
+            .iter()
+            .map(|&p| if use_log { p.ln() } else { p })
+            .collect();
+
+        self.interpolators[subgrid_idx][pid_idx]
+            .interpolate_point(&coordinates)
+            .map_err(|e| Error::InterpolationError(e.to_string()))
+            .map(|result| self.apply_force_positive(result))
+    }
+
+    /// Interpolates PDF values for every combination of flavor and knot.
+    ///
+    /// The combinations are evaluated one after the other with [`GridPDF::xfxq2`].
+    ///
+    /// # Arguments
+    ///
+    /// * `flavors` - A vector of flavor IDs.
+    /// * `slice_points` - A slice containing the collection of knots to interpolate on.
+    ///   A knot is a collection of points containing `(nucleon, alphas, x, Q2)`.
+    ///
+    /// # Returns
+    ///
+    /// A 2D array of interpolated PDF values with shape `[flavors, N_knots]`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the interpolation of any single point returns an error.
+    pub fn xfxq2s(&self, flavors: Vec<i32>, slice_points: &[&[f64]]) -> Array2<f64> {
+        let grid_shape = [flavors.len(), slice_points.len()];
+
+        // Row-major fill: all knots of the first flavor, then all knots of the second, ...
+        let data: Vec<f64> = flavors
+            .iter()
+            .flat_map(|&flavor| {
+                slice_points
+                    .iter()
+                    .map(move |&point| self.xfxq2(flavor, point).unwrap())
+            })
+            .collect();
+
+        Array2::from_shape_vec(grid_shape, data).unwrap()
+    }
+
+    /// Interpolates PDF values for multiple points in parallel using Chebyshev batch interpolation.
+    ///
+    /// # Arguments
+    ///
+    /// * `flavor_id` - The flavor ID.
+    /// * `points` - A slice containing the collection of knots to interpolate on.
+    ///   A knot is a collection of points containing `(nucleon, alphas, x, Q2)`.
+    ///
+    /// # Returns
+    ///
+    /// A `Vec` of interpolated PDF values, in the same order as `points`. A flavor that is
+    /// not part of the grid yields a vector of zeros.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the set does not use the `LogChebyshev` interpolator, if no
+    /// subgrid can be found for a point, or if the batch interpolation itself fails.
+    pub fn xfxq2_cheby_batch(&self, flavor_id: i32, points: &[&[f64]]) -> Result<Vec<f64>, Error> {
+        if points.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        let pid_idx = match self.knot_array.pid_index(flavor_id) {
+            Some(idx) => idx,
+            None => return Ok(vec![0.0; points.len()]),
+        };
+
+        if !matches!(self.info.interpolator_type, InterpolatorType::LogChebyshev) {
+            return Err(Error::InterpolationError(
+                "xfxq2_cheby_batch only supports LogChebyshev interpolator".to_string(),
+            ));
+        }
+
+        // Group the points by the subgrid they fall into, remembering where each point
+        // sits in the input so that the results can be written back in the caller's order.
+        let mut subgrid_groups: HashMap<usize, Vec<(usize, &[f64])>> = HashMap::new();
+        for (i, point) in points.iter().enumerate() {
+            let subgrid_idx = self.knot_array.find_subgrid(point).ok_or_else(|| {
+                let (x, q2) = self.get_x_q2(point);
+                Error::SubgridNotFound { x, q2 }
+            })?;
+
+            subgrid_groups
+                .entry(subgrid_idx)
+                .or_default()
+                .push((i, *point));
+        }
+
+        // Every slot is overwritten exactly once below: each input index belongs to exactly
+        // one group, and a group is only accepted when it yields one result per point.
+        let mut final_results = vec![0.0; points.len()];
+
+        for (subgrid_idx, group) in subgrid_groups {
+            let subgrid = &self.knot_array.subgrids[subgrid_idx];
+
+            let log_points: Vec<Vec<f64>> = group
+                .iter()
+                .map(|(_, point)| point.iter().map(|&v| v.ln()).collect::<Vec<f64>>())
+                .collect();
+
+            let batch_interpolator =
+                InterpolatorFactory::create_batch_interpolator(subgrid, pid_idx)
+                    .map_err(Error::InterpolationError)?;
+
+            let results = batch_interpolator
+                .interpolate(log_points)
+                .map_err(|e| Error::InterpolationError(e.to_string()))?;
+
+            // A short batch would otherwise be zipped away silently and leave gaps.
+            if results.len() != group.len() {
+                return Err(Error::InterpolationError(format!(
+                    "batch interpolation returned {} values for {} points",
+                    results.len(),
+                    group.len()
+                )));
+            }
+
+            // Write each value straight to its original position; no sorting is needed.
+            for (&(original_index, _), result) in group.iter().zip(results) {
+                final_results[original_index] = self.apply_force_positive(result);
+            }
+        }
+
+        Ok(final_results)
+    }
+
+    /// Get the values of the momentum fraction `x` and momentum scale `Q2`.
+    ///
+    /// # Arguments
+    ///
+    /// * `points` - Coordinates of a knot; only its final two entries (`x`, then `q2`) are read.
+    ///
+    /// # Returns
+    ///
+    /// The pair `(x, q2)`.
+    ///
+    /// # Panics
+    ///
+    /// Panics when `points` holds fewer than two values.
+    pub fn get_x_q2(&self, points: &[f64]) -> (f64, f64) {
+        if let [.., x, q2] = points {
+            (*x, *q2)
+        } else {
+            panic!("The inputs must at least be x and Q2.")
+        }
+    }
+
+    /// Evaluates the strong coupling at the scale `Q²`.
+    ///
+    /// # Arguments
+    ///
+    /// * `q2` - The squared energy scale.
+    ///
+    /// # Returns
+    ///
+    /// The value reported by the stored `AlphaS` calculator.
+    pub fn alphas_q2(&self, q2: f64) -> f64 {
+        let calculator = &self.alphas;
+        calculator.alphas_q2(q2)
+    }
+
+    /// Borrows the metadata this PDF was built with.
+    pub fn metadata(&self) -> &MetaData {
+        let Self { info, .. } = self;
+        info
+    }
+
+    /// Reports the parameter ranges spanned by all subgrids of this PDF.
+    pub fn param_ranges(&self) -> RangeParameters {
+        let grid = &self.knot_array;
+        grid.global_ranges()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::metadata::{InterpolatorType, MetaData, MetaDataV1, SetType};
+
+    // Smallest usable grid: one nucleon, one alpha_s, one kT, two `x` knots, two `q2` knots.
+    fn mock_subgrid_data_2d() -> SubgridData {
+        let xs = vec![0.1, 0.2];
+        let q2s = vec![1.0, 2.0];
+        SubgridData {
+            nucleons: vec![1.0],
+            alphas: vec![0.118],
+            kts: vec![0.0],
+            xs,
+            q2s,
+            grid_data: vec![1.0, 2.0, 3.0, 4.0],
+        }
+    }
+
+    // Metadata for a single-member, gluon-only set using bilinear interpolation.
+    fn mock_metadata() -> MetaData {
+        let v1 = MetaDataV1 {
+            set_desc: String::from("Test"),
+            set_index: 0,
+            num_members: 1,
+            x_min: 1e-9,
+            x_max: 1.0,
+            q_min: 1.0,
+            q_max: 1000.0,
+            flavors: vec![21],
+            format: String::from("test"),
+            alphas_q_values: Vec::new(),
+            alphas_vals: Vec::new(),
+            polarised: false,
+            set_type: SetType::SpaceLike,
+            interpolator_type: InterpolatorType::Bilinear,
+            error_type: String::new(),
+            hadron_pid: 2212,
+            git_version: String::new(),
+            code_version: String::new(),
+            flavor_scheme: String::from("variable"),
+            order_qcd: 2,
+            alphas_order_qcd: 2,
+            m_w: 80.4,
+            m_z: 91.2,
+            m_up: 0.0,
+            m_down: 0.0,
+            m_strange: 0.0,
+            m_charm: 1.4,
+            m_bottom: 4.75,
+            m_top: 173.0,
+            alphas_type: String::from("analytic"),
+            number_flavors: 5,
+        };
+        MetaData::new_v1(v1)
+    }
+
+    
fn mock_grid_pdf() -> GridPDF {
+        let grid_array = GridArray::new(vec![mock_subgrid_data_2d()], vec![21]);
+        GridPDF::new(mock_metadata(), grid_array)
+    }
+
+    #[test]
+    fn test_grid_array_creation() {
+        let subgrid_data = vec![SubgridData {
+            nucleons: vec![1.0],
+            alphas: vec![0.118],
+            kts: vec![0.0],
+            xs: vec![1.0, 2.0, 3.0],
+            q2s: vec![4.0, 5.0],
+            grid_data: vec![
+                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+            ],
+        }];
+        let flavors = vec![21, 22];
+        let grid_array = GridArray::new(subgrid_data, flavors);
+
+        assert_eq!(grid_array.subgrids[0].grid.shape(), &[1, 1, 2, 1, 3, 2]);
+        assert!(grid_array.find_subgrid(&[1.5, 4.5]).is_some());
+    }
+
+    #[test]
+    fn test_pid_lookup() {
+        let grid_array = GridArray::new(vec![mock_subgrid_data_2d()], vec![21]);
+        assert!(grid_array.find_subgrid(&[0.15, 1.5]).is_some());
+
+        // The only stored flavor is the gluon, at position 0.
+        assert_eq!(grid_array.pid_index(21), Some(0));
+        // PID 0 is accepted as an alias of 21.
+        assert_eq!(grid_array.pid_index(0), Some(0));
+        // A flavor that is not part of the grid has no index.
+        assert_eq!(grid_array.pid_index(2), None);
+    }
+
+    #[test]
+    fn test_xf_from_index() {
+        let grid_array = GridArray::new(vec![mock_subgrid_data_2d()], vec![21]);
+        let val = grid_array.xf_from_index(0, 0, 0, 0, 0, 21, 0);
+        assert_eq!(val, 1.0);
+        let val2 = grid_array.xf_from_index(0, 0, 0, 1, 1, 21, 0);
+        assert_eq!(val2, 4.0);
+    }
+
+    #[test]
+    fn test_global_ranges() {
+        let grid_array = GridArray::new(vec![mock_subgrid_data_2d()], vec![21]);
+        let ranges = grid_array.global_ranges();
+        assert_eq!(ranges.x.min, 0.1);
+        assert_eq!(ranges.x.max, 0.2);
+        assert_eq!(ranges.q2.min, 1.0);
+        assert_eq!(ranges.q2.max, 2.0);
+    }
+
+    #[test]
+    fn test_grid_pdf_new() {
+        let pdf = mock_grid_pdf();
+        assert_eq!(pdf.metadata().set_index, 0);
+    }
+
+    #[test]
+    fn test_grid_pdf_interpolation() {
+        let pdf = mock_grid_pdf();
+        let result = pdf.xfxq2(21, &[0.15, 1.5]).unwrap();
+        assert!((result - 2.5).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_grid_pdf_unknown_pid_returns_zero() {
+        let pdf = mock_grid_pdf();
+        let result = pdf.xfxq2(999, &[0.15, 1.5]).unwrap();
+        assert_eq!(result, 0.0);
+    }
+
+    #[test]
+    fn test_apply_force_positive() {
+        let mut pdf = mock_grid_pdf();
+        assert_eq!(pdf.apply_force_positive(-1.0), -1.0);
+
+        pdf.set_force_positive(ForcePositive::ClipNegative);
+        assert_eq!(pdf.apply_force_positive(-1.0), 0.0);
+        assert_eq!(pdf.apply_force_positive(5.0), 5.0);
+
+        pdf.set_force_positive(ForcePositive::ClipSmall);
+        assert_eq!(pdf.apply_force_positive(1e-15), 1e-10);
+        assert_eq!(pdf.apply_force_positive(5.0), 5.0);
+
+        pdf.set_force_positive(ForcePositive::NoClipping);
+        assert_eq!(pdf.apply_force_positive(-1.0), -1.0);
+    }
+
+    #[test]
+    fn test_grid_pdf_alphas_q2() {
+        let pdf = mock_grid_pdf();
+        let as_q2 = pdf.alphas_q2(100.0);
+        assert!(as_q2 > 0.0);
+    }
+
+    #[test]
+    fn test_grid_pdf_param_ranges() {
+        let pdf = mock_grid_pdf();
+        let ranges = pdf.param_ranges();
+        assert_eq!(ranges.x.min, 0.1);
+        assert_eq!(ranges.q2.max, 2.0);
+    }
+
+    #[test]
+    fn test_grid_pdf_get_x_q2() {
+        let pdf = mock_grid_pdf();
+        assert_eq!(pdf.get_x_q2(&[0.15, 1.5]), (0.15, 1.5));
+        assert_eq!(pdf.get_x_q2(&[5.0, 0.3, 0.15, 1.5]), (0.15, 1.5));
+    }
+
+    #[test]
+    fn test_grid_pdf_xfxq2s() {
+        let pdf = mock_grid_pdf();
+        let result = pdf.xfxq2s(vec![21], &[&[0.15, 1.5]]);
+        assert_eq!(result.shape(), &[1, 1]);
+        assert!((result[[0, 0]] - 2.5).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_grid_pdf_metadata() {
+        let pdf = mock_grid_pdf();
+        assert_eq!(pdf.metadata().flavors, vec![21]);
+    }
+}
diff --git a/neopdf_legacy/src/interpolator.rs b/neopdf_legacy/src/interpolator.rs
new file mode 100644
index 0000000..323dd09
--- /dev/null
+++ b/neopdf_legacy/src/interpolator.rs
@@ -0,0 +1,853 @@
+//! This module contains the dynamic interpolation traits, InterpolatorFactory, and dynamic
+//! dispatch logic for PDF grids.
+//!
+//! # Contents
+//!
+//! - [`DynInterpolator`]: Trait for dynamic, multi-dimensional interpolation.
+//! - [`InterpolatorFactory`]: Factory for constructing interpolators for SubGrid.
+//!
+//! # Note
+//!
+//! Interpolation strategies are defined in `strategy.rs`.
+//! The [`SubGrid`] struct is defined in `subgrid.rs`.
+
+use ndarray::{s, OwnedRepr};
+use ninterp::data::{InterpData2D, InterpData3D};
+use ninterp::error::InterpolateError;
+use ninterp::interpolator::{
+    Extrapolate, Interp2D, Interp2DOwned, Interp3D, Interp3DOwned, InterpND, InterpNDOwned,
+};
+use ninterp::prelude::*;
+use ninterp::strategy::traits::{Strategy2D, Strategy3D, StrategyND};
+use ninterp::strategy::Linear;
+
+use super::metadata::InterpolatorType;
+use super::strategy::{
+    BilinearInterpolation, LogBicubicInterpolation, LogBilinearInterpolation,
+    LogChebyshevBatchInterpolation, LogChebyshevInterpolation, LogTricubicInterpolation,
+};
+use super::subgrid::SubGrid;
+
+/// Represents the dimensionality and structure of interpolation needed.
+///
+/// This enum is used to select the appropriate interpolation strategy based on the
+/// dimensions of the PDF grid data.
+#[derive(Debug, Clone, Copy)]
+pub enum InterpolationConfig {
+    /// 2D interpolation, typically in `x` (momentum fraction) and `Q²` (energy scale).
+    TwoD,
+    /// 3D interpolation, including a dimension for varying nucleon numbers `A`,
+    /// in addition to `x` and `Q²`.
+    ThreeDNucleons,
+    /// 3D interpolation, including a dimension for varying `alpha_s` values,
+    /// in addition to `x` and `Q²`.
+    ThreeDAlphas,
+    /// 3D interpolation, including a dimension for varying `kT` values,
+    /// in addition to `x` and `Q²`.
+    ThreeDKt,
+    /// 4D interpolation, covering nucleon numbers `A`, `alpha_s`, `x`, and `Q²`.
+    FourDNucleonsAlphas,
+    /// 4D interpolation, covering nucleon numbers `A`, kT, `x`, and `Q²`.
+    FourDNucleonsKt,
+    /// 4D interpolation, covering `alpha_s`, kT, `x`, and `Q²`.
+    FourDAlphasKt,
+    /// 5D interpolation, covering nucleon numbers `A`, `alpha_s`, `kT`, `x`, and `Q²`.
+    FiveD,
+}
+
+impl InterpolationConfig {
+    /// Determines the interpolation configuration from the number of nucleon, alpha_s, and kT values.
+    ///
+    /// An axis takes part in the interpolation only when it has more than one value; `x`
+    /// and `Q²` always do. All eight combinations are covered, so this function does not
+    /// panic.
+    ///
+    /// # Arguments
+    ///
+    /// * `n_nucleons` - Number of nucleon values in the subgrid.
+    /// * `n_alphas` - Number of `alpha_s` values in the subgrid.
+    /// * `n_kts` - Number of `kT` values in the subgrid.
+    pub fn from_dimensions(n_nucleons: usize, n_alphas: usize, n_kts: usize) -> Self {
+        match (n_nucleons > 1, n_alphas > 1, n_kts > 1) {
+            (false, false, false) => Self::TwoD,
+            (true, false, false) => Self::ThreeDNucleons,
+            (false, true, false) => Self::ThreeDAlphas,
+            (false, false, true) => Self::ThreeDKt,
+            (true, true, false) => Self::FourDNucleonsAlphas,
+            (true, false, true) => Self::FourDNucleonsKt,
+            (false, true, true) => Self::FourDAlphasKt,
+            (true, true, true) => Self::FiveD,
+        }
+    }
+}
+
+/// A trait for dynamic interpolation across different dimensions.
+pub trait DynInterpolator: Send + Sync {
+    /// Interpolates at `point`, which holds one coordinate per interpolated dimension.
+    fn interpolate_point(&self, point: &[f64]) -> Result<f64, InterpolateError>;
+}
+
+// Implement `DynInterpolator` for 2D interpolators.
+impl<S> DynInterpolator for Interp2DOwned<f64, S>
+where
+    S: Strategy2D<OwnedRepr<f64>> + 'static + Clone + Send + Sync,
+{
+    fn interpolate_point(&self, point: &[f64]) -> Result<f64, InterpolateError> {
+        // The slice-to-array conversion fails unless exactly two coordinates are given.
+        let [x, y] = point
+            .try_into()
+            .map_err(|_| InterpolateError::Other("Expected 2D point".to_string()))?;
+        self.interpolate(&[x, y])
+    }
+}
+
+// Implement `DynInterpolator` for 3D interpolators.
+impl<S> DynInterpolator for Interp3DOwned<f64, S>
+where
+    S: Strategy3D<OwnedRepr<f64>> + 'static + Clone + Send + Sync,
+{
+    fn interpolate_point(&self, point: &[f64]) -> Result<f64, InterpolateError> {
+        // The slice-to-array conversion fails unless exactly three coordinates are given.
+        let [x, y, z] = point
+            .try_into()
+            .map_err(|_| InterpolateError::Other("Expected 3D point".to_string()))?;
+        self.interpolate(&[x, y, z])
+    }
+}
+
+// Implement `DynInterpolator` for N-dimensional interpolators.
+impl<S> DynInterpolator for InterpNDOwned<f64, S>
+where
+    S: StrategyND<OwnedRepr<f64>> + 'static + Clone + Send + Sync,
+{
+    fn interpolate_point(&self, point: &[f64]) -> Result<f64, InterpolateError> {
+        self.interpolate(point)
+    }
+}
+
+/// An enum to dispatch batch interpolation to the correct Chebyshev interpolator.
+pub enum BatchInterpolator { + Chebyshev2D( + LogChebyshevBatchInterpolation<2>, + InterpData2D>, + ), + Chebyshev3D( + LogChebyshevBatchInterpolation<3>, + InterpData3D>, + ), +} + +impl BatchInterpolator { + /// Interpolates a batch of points. + pub fn interpolate(&self, points: Vec>) -> Result, InterpolateError> { + match self { + BatchInterpolator::Chebyshev2D(strategy, data) => { + let points_2d: Vec<[f64; 2]> = points + .into_iter() + .map(|p| p.try_into().expect("Invalid point dimension for 2D")) + .collect(); + strategy.interpolate(data, &points_2d) + } + BatchInterpolator::Chebyshev3D(strategy, data) => { + let points_3d: Vec<[f64; 3]> = points + .into_iter() + .map(|p| p.try_into().expect("Invalid point dimension for 3D")) + .collect(); + strategy.interpolate(data, &points_3d) + } + } + } +} + +/// Factory for creating dynamic interpolators based on interpolation type and grid dimensions. +pub struct InterpolatorFactory; + +impl InterpolatorFactory { + pub fn create( + interp_type: InterpolatorType, + subgrid: &SubGrid, + pid_index: usize, + ) -> Box { + match subgrid.interpolation_config() { + InterpolationConfig::TwoD => Self::interpolator_xfxq2(interp_type, subgrid, pid_index), + InterpolationConfig::ThreeDNucleons => { + Self::interpolator_xfxq2_nucleons(interp_type, subgrid, pid_index) + } + InterpolationConfig::ThreeDAlphas => { + Self::interpolator_xfxq2_alphas(interp_type, subgrid, pid_index) + } + InterpolationConfig::ThreeDKt => { + Self::interpolator_xfxq2_kts(interp_type, subgrid, pid_index) + } + InterpolationConfig::FourDNucleonsAlphas => { + Self::interpolator_xfxq2_nucleons_alphas(interp_type, subgrid, pid_index) + } + InterpolationConfig::FourDNucleonsKt => { + Self::interpolator_xfxq2_nucleons_kts(interp_type, subgrid, pid_index) + } + InterpolationConfig::FourDAlphasKt => { + Self::interpolator_xfxq2_alphas_kts(interp_type, subgrid, pid_index) + } + InterpolationConfig::FiveD => { + Self::interpolator_xfxq2_5dim(interp_type, subgrid, 
pid_index) + } + } + } + + fn interpolator_xfxq2( + interp_type: InterpolatorType, + subgrid: &SubGrid, + pid_index: usize, + ) -> Box { + let grid_slice = subgrid.grid_slice(pid_index).to_owned(); + + match interp_type { + InterpolatorType::Bilinear => Box::new( + Interp2D::new( + subgrid.xs.to_owned(), + subgrid.q2s.to_owned(), + grid_slice, + BilinearInterpolation, + Extrapolate::Clamp, + ) + .expect("Failed to create 2D interpolator"), + ), + InterpolatorType::LogBilinear => Box::new( + Interp2D::new( + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + grid_slice, + LogBilinearInterpolation, + Extrapolate::Clamp, + ) + .expect("Failed to create 2D interpolator"), + ), + InterpolatorType::LogBicubic => Box::new( + Interp2D::new( + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + grid_slice, + LogBicubicInterpolation::default(), + Extrapolate::Clamp, + ) + .expect("Failed to create 2D interpolator"), + ), + InterpolatorType::LogChebyshev => Box::new( + Interp2D::new( + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + grid_slice, + LogChebyshevInterpolation::<2>::default(), + Extrapolate::Clamp, + ) + .expect("Failed to create 2D interpolator"), + ), + _ => panic!("Unsupported 2D interpolator: {:?}", interp_type), + } + } + + fn interpolator_xfxq2_nucleons( + interp_type: InterpolatorType, + subgrid: &SubGrid, + pid_index: usize, + ) -> Box { + let grid_data = subgrid + .grid + .slice(s![.., 0, pid_index, 0, .., ..]) + .to_owned(); + let reshaped_data = grid_data + .into_shape_with_order((subgrid.nucleons.len(), subgrid.xs.len(), subgrid.q2s.len())) + .expect("Failed to reshape 3D data"); + + match interp_type { + InterpolatorType::LogTricubic => Box::new( + Interp3D::new( + subgrid.nucleons.mapv(f64::ln), + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + reshaped_data, + LogTricubicInterpolation, + Extrapolate::Clamp, + ) + .expect("Failed to create 3D interpolator"), + ), + InterpolatorType::LogChebyshev => Box::new( + 
Interp3D::new( + subgrid.nucleons.mapv(f64::ln), + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + reshaped_data, + LogChebyshevInterpolation::<3>::default(), + Extrapolate::Clamp, + ) + .expect("Failed to create 3D interpolator"), + ), + _ => panic!("Unsupported 3D interpolator: {:?}", interp_type), + } + } + + fn interpolator_xfxq2_alphas( + interp_type: InterpolatorType, + subgrid: &SubGrid, + pid_index: usize, + ) -> Box { + let grid_data = subgrid + .grid + .slice(s![0, .., pid_index, 0, .., ..]) + .to_owned(); + let reshaped_data = grid_data + .into_shape_with_order((subgrid.alphas.len(), subgrid.xs.len(), subgrid.q2s.len())) + .expect("Failed to reshape 3D data"); + + match interp_type { + InterpolatorType::LogTricubic => Box::new( + Interp3D::new( + subgrid.alphas.mapv(f64::ln), + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + reshaped_data, + LogTricubicInterpolation, + Extrapolate::Clamp, + ) + .expect("Failed to create 3D interpolator"), + ), + InterpolatorType::LogChebyshev => Box::new( + Interp3D::new( + subgrid.alphas.mapv(f64::ln), + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + reshaped_data, + LogChebyshevInterpolation::<3>::default(), + Extrapolate::Clamp, + ) + .expect("Failed to create 3D interpolator"), + ), + _ => panic!("Unsupported 3D interpolator: {:?}", interp_type), + } + } + + fn interpolator_xfxq2_kts( + interp_type: InterpolatorType, + subgrid: &SubGrid, + pid_index: usize, + ) -> Box { + let grid_data = subgrid + .grid + .slice(s![0, 0, pid_index, .., .., ..]) + .to_owned(); + let reshaped_data = grid_data + .into_shape_with_order((subgrid.kts.len(), subgrid.xs.len(), subgrid.q2s.len())) + .expect("Failed to reshape 3D data"); + + match interp_type { + InterpolatorType::LogTricubic => Box::new( + Interp3D::new( + subgrid.kts.mapv(f64::ln), + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + reshaped_data, + LogTricubicInterpolation, + Extrapolate::Clamp, + ) + .expect("Failed to create 3D 
interpolator"), + ), + InterpolatorType::LogChebyshev => Box::new( + Interp3D::new( + subgrid.kts.mapv(f64::ln), + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + reshaped_data, + LogChebyshevInterpolation::<3>::default(), + Extrapolate::Clamp, + ) + .expect("Failed to create 3D interpolator"), + ), + _ => panic!("Unsupported 3D interpolator: {:?}", interp_type), + } + } + + fn interpolator_xfxq2_nucleons_alphas( + interp_type: InterpolatorType, + subgrid: &SubGrid, + pid_index: usize, + ) -> Box { + let grid_data = subgrid + .grid + .slice(s![.., .., pid_index, 0, .., ..]) + .to_owned(); + let coords = vec![ + subgrid.nucleons.to_owned(), + subgrid.alphas.to_owned(), + subgrid.xs.to_owned(), + subgrid.q2s.to_owned(), + ]; + let reshaped_data = grid_data + .into_shape_with_order(( + subgrid.nucleons.len(), + subgrid.alphas.len(), + subgrid.xs.len(), + subgrid.q2s.len(), + )) + .expect("Failed to reshape 4D data"); + + match interp_type { + InterpolatorType::InterpNDLinear => Box::new( + InterpND::new(coords, reshaped_data.into_dyn(), Linear, Extrapolate::Clamp) + .expect("Failed to create 4D interpolator"), + ), + _ => panic!("Unsupported 4D interpolator: {:?}", interp_type), + } + } + + fn interpolator_xfxq2_nucleons_kts( + interp_type: InterpolatorType, + subgrid: &SubGrid, + pid_index: usize, + ) -> Box { + let grid_data = subgrid + .grid + .slice(s![.., 0, pid_index, .., .., ..]) + .to_owned(); + let coords = vec![ + subgrid.nucleons.mapv(f64::ln), + subgrid.kts.mapv(f64::ln), + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + ]; + let reshaped_data = grid_data + .into_shape_with_order(( + subgrid.nucleons.len(), + subgrid.kts.len(), + subgrid.xs.len(), + subgrid.q2s.len(), + )) + .expect("Failed to reshape 4D data"); + + match interp_type { + InterpolatorType::InterpNDLinear => Box::new( + InterpND::new(coords, reshaped_data.into_dyn(), Linear, Extrapolate::Clamp) + .expect("Failed to create 4D interpolator"), + ), + _ => panic!("Unsupported 
4D interpolator: {:?}", interp_type), + } + } + + fn interpolator_xfxq2_alphas_kts( + interp_type: InterpolatorType, + subgrid: &SubGrid, + pid_index: usize, + ) -> Box { + let grid_data = subgrid + .grid + .slice(s![0, .., pid_index, .., .., ..]) + .to_owned(); + let coords = vec![ + subgrid.alphas.mapv(f64::ln), + subgrid.kts.mapv(f64::ln), + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + ]; + let reshaped_data = grid_data + .into_shape_with_order(( + subgrid.alphas.len(), + subgrid.kts.len(), + subgrid.xs.len(), + subgrid.q2s.len(), + )) + .expect("Failed to reshape 4D data"); + + match interp_type { + InterpolatorType::InterpNDLinear => Box::new( + InterpND::new(coords, reshaped_data.into_dyn(), Linear, Extrapolate::Clamp) + .expect("Failed to create 4D interpolator"), + ), + _ => panic!("Unsupported 4D interpolator: {:?}", interp_type), + } + } + + fn interpolator_xfxq2_5dim( + interp_type: InterpolatorType, + subgrid: &SubGrid, + pid_index: usize, + ) -> Box { + let grid_data = subgrid + .grid + .slice(s![.., .., pid_index, .., .., ..]) + .to_owned(); + let coords = vec![ + subgrid.nucleons.mapv(f64::ln), + subgrid.alphas.mapv(f64::ln), + subgrid.kts.mapv(f64::ln), + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + ]; + let reshaped_data = grid_data + .into_shape_with_order(( + subgrid.nucleons.len(), + subgrid.alphas.len(), + subgrid.kts.len(), + subgrid.xs.len(), + subgrid.q2s.len(), + )) + .expect("Failed to reshape 5D data"); + + match interp_type { + InterpolatorType::InterpNDLinear => Box::new( + InterpND::new(coords, reshaped_data.into_dyn(), Linear, Extrapolate::Clamp) + .expect("Failed to create 5D interpolator"), + ), + _ => panic!("Unsupported 5D interpolator: {:?}", interp_type), + } + } + + pub fn create_batch_interpolator( + subgrid: &SubGrid, + pid_idx: usize, + ) -> Result { + match subgrid.interpolation_config() { + InterpolationConfig::TwoD => { + let mut strategy = LogChebyshevBatchInterpolation::<2>::default(); + let 
grid_slice = subgrid.grid_slice(pid_idx).to_owned(); + + let data = InterpData2D::new( + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + grid_slice, + ) + .map_err(|e| e.to_string())?; + strategy.init(&data).map_err(|e| e.to_string())?; + + Ok(BatchInterpolator::Chebyshev2D(strategy, data)) + } + InterpolationConfig::ThreeDNucleons => { + let mut strategy = LogChebyshevBatchInterpolation::<3>::default(); + let grid_data = subgrid.grid.slice(s![.., 0, pid_idx, 0, .., ..]).to_owned(); + + let reshaped_data = grid_data + .into_shape_with_order(( + subgrid.nucleons.len(), + subgrid.xs.len(), + subgrid.q2s.len(), + )) + .expect("Failed to reshape 3D data"); + + let data = InterpData3D::new( + subgrid.nucleons.mapv(f64::ln), + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + reshaped_data, + ) + .map_err(|e| e.to_string())?; + strategy.init(&data).map_err(|e| e.to_string())?; + + Ok(BatchInterpolator::Chebyshev3D(strategy, data)) + } + InterpolationConfig::ThreeDAlphas => { + let mut strategy = LogChebyshevBatchInterpolation::<3>::default(); + let grid_data = subgrid.grid.slice(s![0, .., pid_idx, 0, .., ..]).to_owned(); + + let reshaped_data = grid_data + .into_shape_with_order(( + subgrid.alphas.len(), + subgrid.xs.len(), + subgrid.q2s.len(), + )) + .expect("Failed to reshape 3D data"); + + let data = InterpData3D::new( + subgrid.alphas.mapv(f64::ln), + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + reshaped_data, + ) + .map_err(|e| e.to_string())?; + strategy.init(&data).map_err(|e| e.to_string())?; + + Ok(BatchInterpolator::Chebyshev3D(strategy, data)) + } + InterpolationConfig::ThreeDKt => { + let mut strategy = LogChebyshevBatchInterpolation::<3>::default(); + let grid_data = subgrid.grid.slice(s![0, 0, pid_idx, .., .., ..]).to_owned(); + + let reshaped_data = grid_data + .into_shape_with_order((subgrid.kts.len(), subgrid.xs.len(), subgrid.q2s.len())) + .expect("Failed to reshape 3D data"); + + let data = InterpData3D::new( + 
subgrid.kts.mapv(f64::ln), + subgrid.xs.mapv(f64::ln), + subgrid.q2s.mapv(f64::ln), + reshaped_data, + ) + .map_err(|e| e.to_string())?; + strategy.init(&data).map_err(|e| e.to_string())?; + + Ok(BatchInterpolator::Chebyshev3D(strategy, data)) + } + _ => Err("Unsupported dimension for batch interpolation".to_string()), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::subgrid::SubGrid; + + const MAXDIFF: f64 = 1e-15; + + fn mock_subgrid_2d() -> SubGrid { + let xs = vec![0.1, 0.2]; + let q2s = vec![1.0, 2.0]; + let grid_data = vec![1.0, 2.0, 3.0, 4.0]; + SubGrid::new(vec![1.0], vec![0.118], vec![0.0], xs, q2s, 1, grid_data) + } + + fn mock_subgrid_3d_nucleons() -> SubGrid { + let nucleons = vec![1.0, 2.0, 3.0, 4.0]; + let xs = vec![0.1, 0.2, 0.3, 0.4]; + let q2s = vec![1.0, 2.0, 3.0, 4.0]; + let grid_data = (1..=64).map(|v| v as f64).collect(); + SubGrid::new(nucleons, vec![0.118], vec![0.0], xs, q2s, 1, grid_data) + } + + fn mock_subgrid_3d_alphas() -> SubGrid { + let alphas = vec![0.118, 0.120, 0.122, 0.124]; + let xs = vec![0.1, 0.2, 0.3, 0.4]; + let q2s = vec![1.0, 2.0, 3.0, 4.0]; + let grid_data = (1..=64).map(|v| v as f64).collect(); + SubGrid::new(vec![1.0], alphas, vec![0.0], xs, q2s, 1, grid_data) + } + + fn mock_subgrid_3d_kts() -> SubGrid { + let kts = vec![0.5, 1.0, 1.5, 2.0]; + let xs = vec![0.1, 0.2, 0.3, 0.4]; + let q2s = vec![1.0, 2.0, 3.0, 4.0]; + let grid_data = (1..=64).map(|v| v as f64).collect(); + SubGrid::new(vec![1.0], vec![0.118], kts, xs, q2s, 1, grid_data) + } + + fn mock_subgrid_4d_nucleons_alphas() -> SubGrid { + let nucleons = vec![1.0, 2.0]; + let alphas = vec![0.118, 0.120]; + let xs = vec![0.1, 0.2]; + let q2s = vec![1.0, 2.0]; + let grid_data = vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + ]; + SubGrid::new(nucleons, alphas, vec![0.0], xs, q2s, 1, grid_data) + } + + #[test] + fn test_interpolation_config() { + assert!(matches!( + 
InterpolationConfig::from_dimensions(1, 1, 1), + InterpolationConfig::TwoD + )); + assert!(matches!( + InterpolationConfig::from_dimensions(2, 1, 1), + InterpolationConfig::ThreeDNucleons + )); + assert!(matches!( + InterpolationConfig::from_dimensions(1, 2, 1), + InterpolationConfig::ThreeDAlphas + )); + assert!(matches!( + InterpolationConfig::from_dimensions(1, 1, 2), + InterpolationConfig::ThreeDKt + )); + assert!(matches!( + InterpolationConfig::from_dimensions(2, 2, 1), + InterpolationConfig::FourDNucleonsAlphas + )); + assert!(matches!( + InterpolationConfig::from_dimensions(2, 1, 2), + InterpolationConfig::FourDNucleonsKt + )); + assert!(matches!( + InterpolationConfig::from_dimensions(1, 2, 2), + InterpolationConfig::FourDAlphasKt + )); + assert!(matches!( + InterpolationConfig::from_dimensions(2, 2, 2), + InterpolationConfig::FiveD + )); + } + + #[test] + fn test_2d_bilinear_interpolation() { + let subgrid = mock_subgrid_2d(); + let interpolator = InterpolatorFactory::create(InterpolatorType::Bilinear, &subgrid, 0); + let result = interpolator.interpolate_point(&[0.15, 1.5]).unwrap(); + assert!((result - 2.5).abs() < MAXDIFF); + } + + #[test] + fn test_3d_nucleons_interpolation() { + let subgrid = mock_subgrid_3d_nucleons(); + let interpolator = InterpolatorFactory::create(InterpolatorType::LogTricubic, &subgrid, 0); + let result = interpolator + .interpolate_point(&[2.0f64.ln(), 0.2f64.ln(), 2.0f64.ln()]) + .unwrap(); + assert!((result - 22.0).abs() < MAXDIFF); + } + + #[test] + fn test_3d_alphas_interpolation() { + let subgrid = mock_subgrid_3d_alphas(); + let interpolator = InterpolatorFactory::create(InterpolatorType::LogTricubic, &subgrid, 0); + let result = interpolator + .interpolate_point(&[0.120f64.ln(), 0.2f64.ln(), 2.0f64.ln()]) + .unwrap(); + assert!((result - 22.0).abs() < MAXDIFF); + } + + #[test] + fn test_3d_kts_interpolation() { + let subgrid = mock_subgrid_3d_kts(); + let interpolator = 
InterpolatorFactory::create(InterpolatorType::LogTricubic, &subgrid, 0); + let result = interpolator + .interpolate_point(&[1.0f64.ln(), 0.2f64.ln(), 2.0f64.ln()]) + .unwrap(); + assert!((result - 22.0).abs() < MAXDIFF); + } + + #[test] + fn test_4d_nucleons_alphas_interpolation() { + let subgrid = mock_subgrid_4d_nucleons_alphas(); + let interpolator = + InterpolatorFactory::create(InterpolatorType::InterpNDLinear, &subgrid, 0); + let result = interpolator + .interpolate_point(&[1.5, 0.119, 0.15, 1.5]) + .unwrap(); + assert!((result - 8.5).abs() < MAXDIFF); + } + + #[test] + #[should_panic] + fn test_unsupported_interpolator() { + let subgrid = mock_subgrid_2d(); + InterpolatorFactory::create(InterpolatorType::LogTricubic, &subgrid, 0); + } + + fn mock_subgrid_4d_nucleons_kt() -> SubGrid { + let nucleons = vec![1.0, 2.0]; + let kts = vec![0.5, 1.0]; + let xs = vec![0.1, 0.2]; + let q2s = vec![1.0, 2.0]; + let grid_data: Vec = (1..=16).map(|v| v as f64).collect(); + SubGrid::new(nucleons, vec![0.118], kts, xs, q2s, 1, grid_data) + } + + fn mock_subgrid_4d_alphas_kt() -> SubGrid { + let alphas = vec![0.118, 0.120]; + let kts = vec![0.5, 1.0]; + let xs = vec![0.1, 0.2]; + let q2s = vec![1.0, 2.0]; + let grid_data: Vec = (1..=16).map(|v| v as f64).collect(); + SubGrid::new(vec![1.0], alphas, kts, xs, q2s, 1, grid_data) + } + + fn mock_subgrid_5d() -> SubGrid { + let nucleons = vec![1.0, 2.0]; + let alphas = vec![0.118, 0.120]; + let kts = vec![0.5, 1.0]; + let xs = vec![0.1, 0.2]; + let q2s = vec![1.0, 2.0]; + let grid_data: Vec = (1..=32).map(|v| v as f64).collect(); + SubGrid::new(nucleons, alphas, kts, xs, q2s, 1, grid_data) + } + + #[test] + fn test_2d_log_bilinear_interpolation() { + let subgrid = mock_subgrid_2d(); + let interpolator = InterpolatorFactory::create(InterpolatorType::LogBilinear, &subgrid, 0); + let result = interpolator + .interpolate_point(&[0.15f64.ln(), 1.5f64.ln()]) + .unwrap(); + assert!(result.is_finite()); + } + + #[test] + fn 
test_2d_log_bicubic_interpolation() { + let xs = vec![0.1, 0.2, 0.3, 0.4]; + let q2s = vec![1.0, 2.0, 3.0, 4.0]; + let grid_data: Vec = (1..=16).map(|v| v as f64).collect(); + let subgrid = SubGrid::new(vec![1.0], vec![0.118], vec![0.0], xs, q2s, 1, grid_data); + let interpolator = InterpolatorFactory::create(InterpolatorType::LogBicubic, &subgrid, 0); + let result = interpolator + .interpolate_point(&[0.2f64.ln(), 2.0f64.ln()]) + .unwrap(); + assert!(result.is_finite()); + } + + #[test] + fn test_4d_nucleons_kt_interpolation() { + let subgrid = mock_subgrid_4d_nucleons_kt(); + let interpolator = + InterpolatorFactory::create(InterpolatorType::InterpNDLinear, &subgrid, 0); + let result = interpolator + .interpolate_point(&[1.5, 0.75, 0.15, 1.5]) + .unwrap(); + assert!(result.is_finite()); + } + + #[test] + fn test_4d_alphas_kt_interpolation() { + let subgrid = mock_subgrid_4d_alphas_kt(); + let interpolator = + InterpolatorFactory::create(InterpolatorType::InterpNDLinear, &subgrid, 0); + let result = interpolator + .interpolate_point(&[0.119, 0.75, 0.15, 1.5]) + .unwrap(); + assert!(result.is_finite()); + } + + #[test] + fn test_5d_interpolation() { + let subgrid = mock_subgrid_5d(); + let interpolator = + InterpolatorFactory::create(InterpolatorType::InterpNDLinear, &subgrid, 0); + let result = interpolator + .interpolate_point(&[1.5, 0.119, 0.75, 0.15, 1.5]) + .unwrap(); + assert!(result.is_finite()); + } + + #[test] + fn test_create_batch_interpolator_2d() { + let subgrid = mock_subgrid_2d(); + let batch = InterpolatorFactory::create_batch_interpolator(&subgrid, 0).unwrap(); + let points = vec![ + vec![0.15f64.ln(), 1.5f64.ln()], + vec![0.12f64.ln(), 1.2f64.ln()], + ]; + let results = batch.interpolate(points).unwrap(); + assert_eq!(results.len(), 2); + assert!(results.iter().all(|r| r.is_finite())); + } + + #[test] + fn test_create_batch_interpolator_3d_nucleons() { + let subgrid = mock_subgrid_3d_nucleons(); + let batch = 
InterpolatorFactory::create_batch_interpolator(&subgrid, 0).unwrap(); + let pts = vec![vec![2.0f64.ln(), 0.2f64.ln(), 2.0f64.ln()]]; + let results = batch.interpolate(pts).unwrap(); + assert_eq!(results.len(), 1); + } + + #[test] + fn test_create_batch_interpolator_3d_alphas() { + let subgrid = mock_subgrid_3d_alphas(); + let batch = InterpolatorFactory::create_batch_interpolator(&subgrid, 0).unwrap(); + let pts = vec![vec![0.120f64.ln(), 0.2f64.ln(), 2.0f64.ln()]]; + let results = batch.interpolate(pts).unwrap(); + assert_eq!(results.len(), 1); + } + + #[test] + fn test_create_batch_interpolator_3d_kt() { + let subgrid = mock_subgrid_3d_kts(); + let batch = InterpolatorFactory::create_batch_interpolator(&subgrid, 0).unwrap(); + let pts = vec![vec![1.0f64.ln(), 0.2f64.ln(), 2.0f64.ln()]]; + let results = batch.interpolate(pts).unwrap(); + assert_eq!(results.len(), 1); + } + + #[test] + fn test_create_batch_interpolator_unsupported() { + let subgrid = mock_subgrid_4d_nucleons_alphas(); + assert!(InterpolatorFactory::create_batch_interpolator(&subgrid, 0).is_err()); + } +} diff --git a/neopdf_legacy/src/lib.rs b/neopdf_legacy/src/lib.rs new file mode 100644 index 0000000..34c8526 --- /dev/null +++ b/neopdf_legacy/src/lib.rs @@ -0,0 +1,59 @@ +//! # NeoPDF Library +//! +//! NeoPDF is a modern, fast, and reliable Rust library for reading, managing, and interpolating +//! both collinear and transverse momentum Parton Distribution Functions ([TMD] PDFs) from +//! the LHAPDF, TMDlib, and NeoPDF formats. +//! +//! ## Main Features +//! +//! - **Unified PDF Set Interface:** Load, access, and interpolate PDF sets from both LHAPDF and +//! NeoPDF formats using a consistent API. +//! - **High-Performance Interpolation:** Provides multi-dimensional interpolation (including +//! log-bicubic, log-tricubic, and more) for PDF values, supporting advanced use cases in +//! high-energy physics. +//!
- **Flexible Metadata Handling:** Rich metadata structures for describing PDF sets, including +//! support for arbitrary hadron types. +//! - **Conversion and Compression:** Tools to convert LHAPDF sets to NeoPDF format and to combine +//! multiple nuclear PDF sets into a single file with explicit A dependence. +//! - **Efficient Storage:** Compressed storage and random access to large PDF sets using LZ4 and +//! bincode serialization. +//! +//! ## Module Overview +//! +//! - [`converter`]: Utilities for converting and combining PDF sets. +//! - [`gridpdf`]: Core grid data structures and high-level PDF grid interface. +//! - [`interpolator`]: Dynamic interpolation traits and factories for PDF grids. +//! - [`manage`]: Management utilities for PDF set installation, download, and path resolution. +//! - [`metadata`]: Metadata structures and types for describing PDF sets. +//! - [`parser`]: Parsing utilities for reading and interpreting PDF set data files. +//! - [`pdf`]: High-level interface for working with PDF sets and interpolation. +//! - [`strategy`]: Interpolation strategy implementations (bilinear, log-bicubic, etc.). +//! - [`subgrid`]: Subgrid data structures and parameter range logic. +//! - [`utils`]: Utility functions for interpolation and grid operations. +//! - [`writer`]: Utilities for serializing, compressing, and accessing PDF grid data. +//! +//! ## Example Usage +//! +//! ```rust +//! use neopdf_legacy::pdf::PDF; +//! +//! // Load a PDF member from a set (LHAPDF or NeoPDF format) +//! let pdf = PDF::load("NNPDF40_nnlo_as_01180", 0); +//! let xf = pdf.xfxq2(21, &[0.01, 100.0]); +//! println!("xf = {}", xf); +//! ``` +//! +//! See module-level documentation for more details and advanced usage.
+ +pub mod alphas; +pub mod converter; +pub mod gridpdf; +pub mod interpolator; +pub mod manage; +pub mod metadata; +pub mod parser; +pub mod pdf; +pub mod strategy; +pub mod subgrid; +pub mod utils; +pub mod writer; diff --git a/neopdf_legacy/src/manage.rs b/neopdf_legacy/src/manage.rs new file mode 100644 index 0000000..4a13108 --- /dev/null +++ b/neopdf_legacy/src/manage.rs @@ -0,0 +1,206 @@ +//! This module provides management utilities for PDF set installation, download, and path resolution. +//! +//! It defines types and methods for ensuring that PDF sets are available locally, downloading them if +//! necessary, and handling different PDF set formats (LHAPDF, NeoPDF). +use flate2::read::GzDecoder; +use indicatif::{ProgressBar, ProgressStyle}; +use serde::{Deserialize, Serialize}; +use std::error::Error; +use std::io::Read; +use std::path::{Path, PathBuf}; +use tar::Archive; + +/// On-disk format of a PDF set; decides whether an installed set must be a file or a directory. +#[derive(Debug, Deserialize, Serialize)] +pub enum PdfSetFormat { + /// LHAPDF set: considered installed when the set path exists and is a directory. + Lhapdf, + /// NeoPDF set: considered installed when the set path exists and is a regular file. + Neopdf, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct ManageData { + neopdf_path: PathBuf, + set_name: String, + pdfset_path: PathBuf, + pdfset_format: PdfSetFormat, +} + +impl ManageData { + pub fn new(set_name: &str, format: PdfSetFormat) -> Self { + let data_path = Self::get_data_path(); + let xpdf_path = data_path.join(set_name); + + let manager = Self { + neopdf_path: data_path, + set_name: set_name.to_string(), + pdfset_path: xpdf_path, + pdfset_format: format, + }; + manager.ensure_pdf_installed().unwrap(); + + manager + } + + pub fn get_data_path() -> PathBuf { + // Check for NEOPDF_DATA_PATH environment variable first + if let Ok(neopdf_data_path) = std::env::var("NEOPDF_DATA_PATH") { + let neopdf_dir = PathBuf::from(neopdf_data_path); + + if !neopdf_dir.exists() { + std::fs::create_dir_all(&neopdf_dir).unwrap(); + } + + return neopdf_dir; + } + + // Falls back to `$HOME/.local/share/neopdf` (XDG_DATA_HOME is not consulted) if the env. variable is not set.
+ // TODO: Make this more robust and not platform-dependent + let home = std::env::var("HOME") + .map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::NotFound, + "HOME environment variable not found", + ) + }) + .unwrap(); + let data_dir = PathBuf::from(home).join(".local").join("share"); + let neopdf_dir = data_dir.join("neopdf"); + + if !neopdf_dir.exists() { + std::fs::create_dir_all(&neopdf_dir).unwrap(); + } + + neopdf_dir + } + + /// Download the PDF set and extract it into the designated path. + /// The download happens in memory so no `*.tar.*` is written. + pub fn download_pdf(&self) -> Result<(), Box> { + let url = format!( + "https://lhapdfsets.web.cern.ch/current/{}.tar.gz", + self.set_name + ); + println!("Downloading PDF set from: {}", url); + + let response = ureq::get(&url).call()?; + + let total_size = response + .header("content-length") + .and_then(|s| s.parse::().ok()) + .unwrap_or(0); + + let pb = ProgressBar::new(total_size); + pb.set_style(ProgressStyle::default_bar() + .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")? + .progress_chars("=>-")); + + let mut response_bytes = Vec::new(); + let mut decorated_response = pb.wrap_read(response.into_reader()); + decorated_response.read_to_end(&mut response_bytes)?; + + let tar = GzDecoder::new(&response_bytes[..]); + let mut archive = Archive::new(tar); + + archive.unpack(&self.neopdf_path)?; + + Ok(()) + } + + /// Check that the PDF set is installed in the correct path. + pub fn is_pdf_installed(&self) -> bool { + match self.pdfset_format { + PdfSetFormat::Neopdf => self.pdfset_path.exists() && self.pdfset_path.is_file(), + _ => self.pdfset_path.exists() && self.pdfset_path.is_dir(), + } + } + + /// Ensure that the PDF set is installed, otherwise download it. 
+ pub fn ensure_pdf_installed(&self) -> Result<(), Box> { + if self.is_pdf_installed() { + return Ok(()); + } + + println!("PDF set '{}' not found, downloading...", self.set_name); + self.download_pdf() + } + + /// Get the name of the PDF set. + pub fn set_name(&self) -> &str { + &self.set_name + } + + /// Get the path where PDF sets are stored. + pub fn data_path(&self) -> &Path { + &self.neopdf_path + } + + /// Get the full path to this specific PDF set. + pub fn set_path(&self) -> &Path { + &self.pdfset_path + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::TempDir; + + #[test] + fn test_get_data_path_from_env() { + let tmp = TempDir::new().unwrap(); + let tmp_path = tmp.path().to_str().unwrap().to_string(); + // SAFETY: sound only while no other thread touches the environment. NOTE(review): tests run in parallel by default and other tests in this module mutate the same variable, which is removed (not restored) below. + unsafe { std::env::set_var("NEOPDF_DATA_PATH", &tmp_path) }; + let path = ManageData::get_data_path(); + unsafe { std::env::remove_var("NEOPDF_DATA_PATH") }; + assert_eq!(path, tmp.path()); + } + + #[test] + fn test_get_data_path_creates_dir() { + let tmp = TempDir::new().unwrap(); + let new_dir = tmp.path().join("neopdf_test_subdir"); + unsafe { std::env::set_var("NEOPDF_DATA_PATH", new_dir.to_str().unwrap()) }; + let path = ManageData::get_data_path(); + unsafe { std::env::remove_var("NEOPDF_DATA_PATH") }; + assert!(path.exists()); + } + + #[test] + fn test_manage_data_with_existing_neopdf_file() { + let tmp = TempDir::new().unwrap(); + let set_name = "fake_set.neopdf.lz4"; + let fake_file = tmp.path().join(set_name); + fs::write(&fake_file, b"fake").unwrap(); + + unsafe { std::env::set_var("NEOPDF_DATA_PATH", tmp.path().to_str().unwrap()) }; + let mgr = ManageData::new(set_name, PdfSetFormat::Neopdf); + unsafe { std::env::remove_var("NEOPDF_DATA_PATH") }; + + assert_eq!(mgr.set_name(), set_name); + assert_eq!(mgr.set_path(), fake_file.as_path()); + assert!(mgr.data_path().exists()); + assert!(mgr.is_pdf_installed()); + } + + #[test] + fn
test_is_pdf_installed_missing() { + let tmp = TempDir::new().unwrap(); + let set_name = "nonexistent_set.neopdf.lz4"; + + unsafe { std::env::set_var("NEOPDF_DATA_PATH", tmp.path().to_str().unwrap()) }; + // Manually build the struct to avoid triggering a download + let mgr = ManageData { + neopdf_path: tmp.path().to_path_buf(), + set_name: set_name.to_string(), + pdfset_path: tmp.path().join(set_name), + pdfset_format: PdfSetFormat::Neopdf, + }; + unsafe { std::env::remove_var("NEOPDF_DATA_PATH") }; + + assert!(!mgr.is_pdf_installed()); + } +} diff --git a/neopdf_legacy/src/metadata.rs b/neopdf_legacy/src/metadata.rs new file mode 100644 index 0000000..c1c24ca --- /dev/null +++ b/neopdf_legacy/src/metadata.rs @@ -0,0 +1,320 @@ +//! This module defines metadata structures and types for describing PDF sets. +//! +//! It includes the `MetaData` struct (deserialized from .info files), PDF set +//! and interpolator type enums, and related utilities for handling PDF set information. +use serde::{Deserialize, Deserializer, Serialize}; +use std::fmt; +use std::ops::{Deref, DerefMut}; + +/// Represents the type of PDF set. +#[repr(C)] +#[derive(Clone, Debug, Deserialize, Serialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum SetType { + #[default] + SpaceLike, + TimeLike, +} + +/// Represents the type of interpolator used for the PDF. +/// WARNING: When adding elements, always append to the end!!! +#[repr(C)] +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub enum InterpolatorType { + Bilinear, + LogBilinear, + #[default] + LogBicubic, + LogTricubic, + InterpNDLinear, + LogChebyshev, +} + +/// Represents the information block of a given set. +/// +/// In order to support LHAPDF formats, the fields here are very much influenced by the +/// LHAPDF `.info` file. This struct is generally deserialized from a YAML-like format. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct MetaDataV1 { + /// Description of the PDF set. 
+ #[serde(rename = "SetDesc")] + pub set_desc: String, + /// Index of the PDF set. + #[serde(rename = "SetIndex")] + pub set_index: u32, + /// Number of members in the PDF set (e.g., for error analysis). + #[serde(rename = "NumMembers")] + pub num_members: u32, + /// Minimum x-value for which the PDF is valid. + #[serde(rename = "XMin")] + pub x_min: f64, + /// Maximum x-value for which the PDF is valid. + #[serde(rename = "XMax")] + pub x_max: f64, + /// Minimum Q-value (energy scale) for which the PDF is valid. + #[serde(rename = "QMin")] + pub q_min: f64, + /// Maximum Q-value (energy scale) for which the PDF is valid. + #[serde(rename = "QMax")] + pub q_max: f64, + /// List of particle data group (PDG) IDs for the flavors included in the PDF. + #[serde(rename = "Flavors")] + pub flavors: Vec, + /// Format of the PDF data. + #[serde(rename = "Format")] + pub format: String, + /// AlphaS Q values (non-squared) for interpolation. + #[serde(rename = "AlphaS_Qs", default)] + pub alphas_q_values: Vec, + /// AlphaS values for interpolation. + #[serde(rename = "AlphaS_Vals", default)] + pub alphas_vals: Vec, + /// Whether the hadrons are polarised. + #[serde(rename = "Polarized", default)] + pub polarised: bool, + /// Type of the PDF set (space-like or time-like). + #[serde(rename = "SetType", default)] + pub set_type: SetType, + /// Type of interpolator used for the PDF (e.g., "LogBicubic"). + #[serde(rename = "InterpolatorType", default)] + pub interpolator_type: InterpolatorType, + /// The error type representation of the PDF. + #[serde(rename = "ErrorType", default)] + pub error_type: String, + /// The hadron PID value representation of the PDF. + #[serde(rename = "Particle", default)] + pub hadron_pid: i32, + /// The git version of the code that generated the PDF. + #[serde(rename = "GitVersion", default)] + pub git_version: String, + /// The code version (CARGO_PKG_VERSION) that generated the PDF.
+ #[serde(rename = "CodeVersion", default)] + pub code_version: String, + /// Scheme for the treatment of heavy flavors + #[serde(rename = "FlavorScheme", default)] + pub flavor_scheme: String, + /// Number of QCD loops in the calculation of PDF evolution. + #[serde(rename = "OrderQCD", default)] + pub order_qcd: u32, + /// Number of QCD loops in the calculation of `alpha_s`. + #[serde(rename = "AlphaS_OrderQCD", default)] + pub alphas_order_qcd: u32, + /// Value of the W boson mass. + #[serde(rename = "MW", default)] + pub m_w: f64, + /// Value of the Z boson mass. + #[serde(rename = "MZ", default)] + pub m_z: f64, + /// Value of the Up quark mass. + #[serde(rename = "MUp", default)] + pub m_up: f64, + /// Value of the Down quark mass. + #[serde(rename = "MDown", default)] + pub m_down: f64, + /// Value of the Strange quark mass. + #[serde(rename = "MStrange", default)] + pub m_strange: f64, + /// Value of the Charm quark mass. + #[serde(rename = "MCharm", default)] + pub m_charm: f64, + /// Value of the Bottom quark mass. + #[serde(rename = "MBottom", default)] + pub m_bottom: f64, + /// Value of the Top quark mass. + #[serde(rename = "MTop", default)] + pub m_top: f64, + /// Type of strong coupling computations. + #[serde(rename = "AlphaS_Type", default)] + pub alphas_type: String, + /// Number of active PDF flavors. + #[serde(rename = "NumFlavors", default)] + pub number_flavors: u32, +} + +/// Version-aware metadata wrapper that handles serialization compatibility. +#[derive(Clone, Debug, Serialize)] +#[serde(untagged)] +pub enum MetaData { + V1(MetaDataV1), +} + +impl MetaData { + /// Creates a new instance of V1 `MetaData`. + pub fn new_v1(data: MetaDataV1) -> Self { + Self::V1(data) + } + + /// Gets the current version as the latest available version. + pub fn current_v1(data: MetaDataV1) -> Self { + Self::V1(data) + } + + /// Gets the underlying data as the latest version. 
+ pub fn as_latest(&self) -> MetaDataV1 { + match self { + MetaData::V1(data) => data.clone(), + } + } +} + +impl Deref for MetaData { + type Target = MetaDataV1; + + fn deref(&self) -> &Self::Target { + match self { + MetaData::V1(data) => data, + } + } +} + +impl DerefMut for MetaData { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + MetaData::V1(data) => data, + } + } +} + +impl<'de> Deserialize<'de> for MetaData { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let v1 = MetaDataV1::deserialize(deserializer)?; + + Ok(MetaData::V1(v1)) + } +} + +impl fmt::Display for MetaData { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "Set Description: {}", self.set_desc)?; + writeln!(f, "Set Index: {}", self.set_index)?; + writeln!(f, "Number of Members: {}", self.num_members)?; + writeln!(f, "XMin: {}", self.x_min)?; + writeln!(f, "XMax: {}", self.x_max)?; + writeln!(f, "QMin: {}", self.q_min)?; + writeln!(f, "QMax: {}", self.q_max)?; + writeln!(f, "Flavors: {:?}", self.flavors)?; + writeln!(f, "Format: {}", self.format)?; + writeln!(f, "AlphaS Q Values: {:?}", self.alphas_q_values)?; + writeln!(f, "AlphaS Values: {:?}", self.alphas_vals)?; + writeln!(f, "Polarized: {}", self.polarised)?; + writeln!(f, "Set Type: {:?}", self.set_type)?; + writeln!(f, "Interpolator Type: {:?}", self.interpolator_type)?; + writeln!(f, "Error Type: {}", self.error_type)?; + writeln!(f, "Particle: {}", self.hadron_pid)?; + writeln!(f, "Flavor Scheme: {}", self.flavor_scheme)?; + writeln!(f, "Order QCD: {}", self.order_qcd)?; + writeln!(f, "AlphaS Order QCD: {}", self.alphas_order_qcd)?; + writeln!(f, "MW: {}", self.m_w)?; + writeln!(f, "MZ: {}", self.m_z)?; + writeln!(f, "MUp: {}", self.m_up)?; + writeln!(f, "MDown: {}", self.m_down)?; + writeln!(f, "MStrange: {}", self.m_strange)?; + writeln!(f, "MCharm: {}", self.m_charm)?; + writeln!(f, "MBottom: {}", self.m_bottom)?; + writeln!(f, "MTop: {}", self.m_top)?; + 
writeln!(f, "AlphaS Type: {}", self.alphas_type)?; + writeln!(f, "Number of PDF flavors: {}", self.number_flavors) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_v1() -> MetaDataV1 { + MetaDataV1 { + set_desc: "TestSet".to_string(), + set_index: 42, + num_members: 3, + x_min: 1e-9, + x_max: 1.0, + q_min: 1.0, + q_max: 1000.0, + flavors: vec![21, 1, -1], + format: "lhapdf".to_string(), + alphas_q_values: vec![10.0], + alphas_vals: vec![0.118], + polarised: false, + set_type: SetType::SpaceLike, + interpolator_type: InterpolatorType::LogBicubic, + error_type: "replicas".to_string(), + hadron_pid: 2212, + git_version: "abc".to_string(), + code_version: "0.2.0".to_string(), + flavor_scheme: "variable".to_string(), + order_qcd: 2, + alphas_order_qcd: 2, + m_w: 80.4, + m_z: 91.2, + m_up: 0.0, + m_down: 0.0, + m_strange: 0.0, + m_charm: 1.4, + m_bottom: 4.75, + m_top: 173.0, + alphas_type: "analytic".to_string(), + number_flavors: 5, + } + } + + #[test] + fn test_new_v1_and_deref() { + let meta = MetaData::new_v1(make_v1()); + assert_eq!(meta.set_index, 42); + assert_eq!(meta.flavors, vec![21, 1, -1]); + } + + #[test] + fn test_current_v1() { + let meta = MetaData::current_v1(make_v1()); + assert_eq!(meta.set_desc, "TestSet"); + } + + #[test] + fn test_as_latest() { + let meta = MetaData::new_v1(make_v1()); + let v1 = meta.as_latest(); + assert_eq!(v1.num_members, 3); + assert_eq!(v1.q_max, 1000.0); + } + + #[test] + fn test_deref_mut() { + let mut meta = MetaData::new_v1(make_v1()); + meta.set_index = 99; + assert_eq!(meta.set_index, 99); + } + + #[test] + fn test_display() { + let meta = MetaData::new_v1(make_v1()); + let s = format!("{meta}"); + assert!(s.contains("TestSet")); + assert!(s.contains("Set Index: 42")); + assert!(s.contains("Number of Members: 3")); + } + + #[test] + fn test_set_type_default() { + assert!(matches!(SetType::default(), SetType::SpaceLike)); + } + + #[test] + fn test_interpolator_type_default() { + assert!(matches!( + 
InterpolatorType::default(), + InterpolatorType::LogBicubic + )); + } + + #[test] + fn test_deserialize_roundtrip() { + let meta = MetaData::new_v1(make_v1()); + let yaml = serde_yaml::to_string(&meta).unwrap(); + let restored: MetaData = serde_yaml::from_str(&yaml).unwrap(); + assert_eq!(restored.set_index, 42); + } +} diff --git a/neopdf_legacy/src/parser.rs b/neopdf_legacy/src/parser.rs new file mode 100644 index 0000000..2d94f2b --- /dev/null +++ b/neopdf_legacy/src/parser.rs @@ -0,0 +1,344 @@ +//! This module provides parsing utilities for reading and interpreting PDF set data files. +//! +//! It defines types and methods for loading, parsing, and representing both LHAPDF and NeoPDF +//! set formats, including subgrid data extraction and metadata reading. +use serde::{Deserialize, Serialize}; +use std::fs; +use std::path::{Path, PathBuf}; + +use super::gridpdf::GridArray; +use super::manage::{ManageData, PdfSetFormat}; +use super::metadata::MetaData; +use super::writer::{GridArrayReader, LazyGridArrayIterator}; + +/// Represents the data for a single subgrid within a PDF data file. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SubgridData { + pub nucleons: Vec, + pub alphas: Vec, + pub kts: Vec, + pub xs: Vec, + pub q2s: Vec, + pub grid_data: Vec, +} + +/// Represents the parsed data from an LHAPDF `.dat` file. +#[derive(Debug, Serialize, Deserialize)] +pub struct PdfData { + pub subgrid_data: Vec, + pub pids: Vec, + pub alphas_q_values: Option>, + pub alphas_vals: Option>, +} + +/// Manages the loading and parsing of LHAPDF data sets. +/// +/// This struct provides methods to read metadata and member data files +/// for a given LHAPDF set. +pub struct LhapdfSet { + manager: ManageData, + pub info: MetaData, +} + +impl LhapdfSet { + /// Creates a new `LhapdfSet` instance for a given PDF set name. + /// + /// This constructor initializes the data manager and reads the metadata + /// for the specified PDF set. 
+ /// + /// # Arguments + /// + /// * `pdf_name` - The name of the PDF set (e.g., "NNPDF40_nnlo_as_01180"). + pub fn new(pdf_name: &str) -> Self { + let manager = ManageData::new(pdf_name, PdfSetFormat::Lhapdf); + let pdfset_path = manager.set_path(); + let info_path = pdfset_path.join(format!( + "{}.info", + pdfset_path.file_name().unwrap().to_str().unwrap() + )); + let info: MetaData = Self::read_metadata(&info_path).unwrap(); + + Self { manager, info } + } + + /// Reads the metadata and data for a specific member of the PDF set. + /// + /// # Arguments + /// + /// * `member` - The ID of the PDF member to load. + /// + /// # Returns + /// + /// A tuple containing the `MetaData` and `PdfData` for the specified member. + pub fn member(&self, member: usize) -> (MetaData, GridArray) { + let pdfset_path = self.manager.set_path(); + let data_path = pdfset_path.join(format!( + "{}_{:04}.dat", + pdfset_path.file_name().unwrap().to_str().unwrap(), + member + )); + + let pdf_data = Self::read_data(&data_path); + let knot_array = GridArray::new(pdf_data.subgrid_data, pdf_data.pids); + + let mut info = self.info.clone(); + if info.alphas_vals.is_empty() { + if let (Some(vals), Some(q_values)) = (pdf_data.alphas_vals, pdf_data.alphas_q_values) { + if !vals.is_empty() && !q_values.is_empty() { + info.alphas_vals = vals; + info.alphas_q_values = q_values; + } + } + } + (info, knot_array) + } + + /// Reads the metadata and data for all members of the PDF set. + /// + /// # Returns + /// + /// A vector of tuples, where each tuple contains the `MetaData` and `PdfData` + /// for a member of the set. + pub fn members(&self) -> Vec<(MetaData, GridArray)> { + (0..self.info.num_members as usize) + .map(|i| self.member(i)) + .collect() + } + + /// Reads the `.info` file for a PDF set and deserializes it into an `Info` struct. + /// + /// # Arguments + /// + /// * `path` - The path to the `.info` file. 
+ /// + /// # Returns + /// + /// A `Result` containing the `Info` struct if successful, or a `serde_yaml::Error` otherwise. + fn read_metadata(path: &Path) -> Result { + let content = fs::read_to_string(path).unwrap(); + serde_yaml::from_str(&content) + } + + /// Reads an LHAPDF `.dat` file for a PDF set and parses its content. + /// + /// This function extracts x-knots, Q2-knots, flavor IDs, and the grid data + /// from the specified data file. It can handle files with multiple subgrids + /// separated by "---". + /// + /// # Arguments + /// + /// * `path` - The path to the `.dat` file. + /// + /// # Returns + /// + /// A `PdfData` struct containing the parsed subgrid data and flavor IDs. + pub fn read_data(path: &Path) -> PdfData { + let content = fs::read_to_string(path).unwrap(); + let mut subgrid_data = Vec::new(); + let mut flavors = Vec::new(); + let mut alphas_q_values: Option> = None; + let mut alphas_vals: Option> = None; + + let blocks: Vec<&str> = content.split("---").map(|s| s.trim()).collect(); + + // NOTE: support cases in which `AlphaS` grid info are in `.dat` files. 
+ if !blocks.is_empty() { + #[derive(serde::Deserialize)] + struct DatMeta { + #[serde(rename = "AlphaS_Qs", default)] + alphas_q_values: Vec, + #[serde(rename = "AlphaS_Vals", default)] + alphas_vals: Vec, + } + + let metadata_block = blocks[0]; + if let Ok(dat_meta) = serde_yaml::from_str::(metadata_block) { + if !dat_meta.alphas_q_values.is_empty() { + alphas_q_values = Some(dat_meta.alphas_q_values); + } + if !dat_meta.alphas_vals.is_empty() { + alphas_vals = Some(dat_meta.alphas_vals); + } + } + } + + for block in blocks.iter().skip(1) { + if block.is_empty() { + continue; + } + + let mut lines = block.lines(); + + let x_knots_line = lines.next().unwrap(); + let xs: Vec = x_knots_line + .split_whitespace() + .filter_map(|s| s.parse().ok()) + .collect(); + + let q2_knots_line = lines.next().unwrap(); + let q2s: Vec = q2_knots_line + .split_whitespace() + .filter_map(|s| s.parse().ok()) + .map(|q: f64| q * q) + .collect(); + + // Read the flavors (only once from the first subgrid) + if flavors.is_empty() { + let flavors_line = lines.next().unwrap(); + flavors = flavors_line + .split_whitespace() + .filter_map(|s| s.parse().ok()) + .collect(); + } else { + // Skip the flavors line in subsequent subgrids + lines.next(); + } + + let mut grid_data = Vec::new(); + for line in lines { + let values: Vec = line + .split_whitespace() + .filter_map(|s| s.parse().ok()) + .collect(); + grid_data.extend(values); + } + + // NOTE: given that there isn't really a proper way to extract the + // following values from LHAPDF, their defaults are set to zeros. + let nucleons: Vec = vec![0.0]; + let alphas: Vec = vec![0.0]; + let kts: Vec = vec![0.0]; + + subgrid_data.push(SubgridData { + nucleons, + alphas, + kts, + xs, + q2s, + grid_data, + }); + } + + PdfData { + subgrid_data, + pids: flavors, + alphas_q_values, + alphas_vals, + } + } +} + +/// Manages the loading and parsing of NeoPDF sets. 
+pub struct NeopdfSet { + pub info: MetaData, + grid_reader: GridArrayReader, + setpath: PathBuf, +} + +impl NeopdfSet { + /// TODO + pub fn new(pdf_name: &str) -> Self { + let manager = ManageData::new(pdf_name, PdfSetFormat::Neopdf); + let neopdf_setpath = manager.set_path(); + let grid_readers = GridArrayReader::from_file(neopdf_setpath).unwrap(); + let metadata_info = grid_readers.metadata().as_ref().clone(); + + Self { + info: metadata_info, + grid_reader: grid_readers, + setpath: neopdf_setpath.to_path_buf(), + } + } + + /// TODO + pub fn member(&self, member: usize) -> (MetaData, GridArray) { + let load_grid = self.grid_reader.load_grid(member).unwrap(); + (self.info.clone(), load_grid.grid) + } + + /// TODO + pub fn into_lazy_iterators(&self) -> LazyGridArrayIterator { + LazyGridArrayIterator::from_file(&self.setpath).unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use tempfile::NamedTempFile; + + #[test] + fn test_read_info() { + let yaml_content = r#" + SetDesc: "NNPDF40_nnlo_as_01180" + SetIndex: 4000 + NumMembers: 101 + XMin: 1.0e-9 + XMax: 1.0 + QMin: 1.0 + QMax: 10000.0 + Flavors: [21, 1, 2, 3, 4, 5, -1, -2, -3, -4, -5] + Format: "LHAPDF" + "#; + let mut temp_file = NamedTempFile::new().unwrap(); + write!(temp_file, "{}", yaml_content).unwrap(); + let info = LhapdfSet::read_metadata(temp_file.path()).unwrap(); + + assert_eq!(info.set_desc, "NNPDF40_nnlo_as_01180"); + assert_eq!(info.set_index, 4000); + assert_eq!(info.num_members, 101); + assert_eq!(info.x_min, 1.0e-9); + assert_eq!(info.x_max, 1.0); + assert_eq!(info.q_min, 1.0); + assert_eq!(info.q_max, 10000.0); + assert_eq!(info.flavors, vec![21, 1, 2, 3, 4, 5, -1, -2, -3, -4, -5]); + assert_eq!(info.format, "LHAPDF"); + } + + #[test] + fn test_read_data() { + let data_content = r#" + # Some header + --- + 1.0e-9 1.0e-8 1.0e-7 + 1.0 10.0 100.0 + 21 1 2 + 1.0 2.0 3.0 + 4.0 5.0 6.0 + 7.0 8.0 9.0 + --- + 1.0e-7 1.0e-6 1.0e-5 + 100.0 1000.0 10000.0 + 21 1 2 + 10.0 
11.0 12.0 + 13.0 14.0 15.0 + 16.0 17.0 18.0 + "#; + let mut temp_file = NamedTempFile::new().unwrap(); + write!(temp_file, "{}", data_content).unwrap(); + let pdf_data = LhapdfSet::read_data(temp_file.path()); + + assert_eq!(pdf_data.pids, vec![21, 1, 2]); + assert_eq!(pdf_data.subgrid_data.len(), 2); + + // Check the first subgrid + assert_eq!(pdf_data.subgrid_data[0].xs, vec![1.0e-9, 1.0e-8, 1.0e-7]); + assert_eq!(pdf_data.subgrid_data[0].q2s, vec![1.0, 100.0, 10000.0]); // Q values are squared + assert_eq!( + pdf_data.subgrid_data[0].grid_data, + vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] + ); + + // Check the second subgrid + assert_eq!(pdf_data.subgrid_data[1].xs, vec![1.0e-7, 1.0e-6, 1.0e-5]); + assert_eq!( + pdf_data.subgrid_data[1].q2s, + vec![10000.0, 1000000.0, 100000000.0] + ); // Q values are squared + assert_eq!( + pdf_data.subgrid_data[1].grid_data, + vec![10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0] + ); + } +} diff --git a/neopdf_legacy/src/pdf.rs b/neopdf_legacy/src/pdf.rs new file mode 100644 index 0000000..efa5a01 --- /dev/null +++ b/neopdf_legacy/src/pdf.rs @@ -0,0 +1,556 @@ +//! This module provides the high-level interface for working with PDF sets. +//! +//! It defines the [`PDF`] struct, which serves as the main entry point for accessing, +//! interpolating, and retrieving metadata from PDF sets. The module abstracts over different +//! PDF set formats (LHAPDF and NeoPDF) and provides convenient loader functions for both +//! single and multiple PDF members. +//! +//! # Main Features +//! +//! - Unified interface for loading and accessing PDF sets from different formats. +//! - Parallel loading of all PDF members for efficient batch operations. +//! - High-level interpolation methods for PDF values and strong coupling constant (`alpha_s`). +//! - Access to underlying grid data and metadata for advanced use cases. +//! +//! # Key Types +//! +//! 
- [`PDF`]: Represents a single PDF member, providing methods for interpolation and metadata access. +//! - [`PdfSet`]: Trait for abstracting over different PDF set backends. +//! - Loader functions: [`PDF::load`], [`PDF::load_pdfs`], and internal helpers for batch loading. +//! +//! See the documentation for [`PDF`] for more details on available methods and usage patterns. +use ndarray::{Array1, Array2}; +use rayon::prelude::*; + +use super::gridpdf::{ForcePositive, GridArray, GridPDF}; +use super::metadata::MetaData; +use super::parser::{LhapdfSet, NeopdfSet}; +use super::subgrid::{RangeParameters, SubGrid}; + +/// Trait for abstracting over different PDF set backends (e.g., LHAPDF, NeoPDF). +/// +/// Provides a unified interface for accessing the number of members and retrieving individual +/// members as metadata and grid arrays. +trait PdfSet: Send + Sync { + /// Returns the number of members in the PDF set. + fn num_members(&self) -> usize; + /// Retrieves the metadata and grid array for the specified member index. + fn member(&self, idx: usize) -> (MetaData, GridArray); +} + +impl PdfSet for LhapdfSet { + fn num_members(&self) -> usize { + self.info.num_members as usize + } + fn member(&self, idx: usize) -> (MetaData, GridArray) { + self.member(idx) + } +} + +impl PdfSet for NeopdfSet { + fn num_members(&self) -> usize { + self.info.num_members as usize + } + fn member(&self, idx: usize) -> (MetaData, GridArray) { + self.member(idx) + } +} + +/// Loads a single PDF member from a generic PDF set backend. +/// +/// # Arguments +/// +/// * `set` - The PDF set backend implementing [`PdfSet`]. +/// * `member` - The index of the member to load. +/// +/// # Returns +/// +/// A [`PDF`] instance for the specified member. +fn pdfset_loader(set: T, member: usize) -> PDF { + let (info, knot_array) = set.member(member); + PDF { + grid_pdf: GridPDF::new(info, knot_array), + } +} + +/// Loads all PDF members from a generic PDF set backend in sequential. 
+/// +/// # Arguments +/// +/// * `set` - The PDF set backend implementing [`PdfSet`]. +/// +/// # Returns +/// +/// A vector of [`PDF`] instances, one for each member in the set. +fn pdfsets_seq_loader(set: T) -> Vec { + (0..set.num_members()) + .map(|idx| { + let (info, knot_array) = set.member(idx); + PDF { + grid_pdf: GridPDF::new(info, knot_array), + } + }) + .collect() +} + +/// Loads all PDF members from a generic PDF set backend in parallel. +/// +/// # Arguments +/// +/// * `set` - The PDF set backend implementing [`PdfSet`]. +/// +/// # Returns +/// +/// A vector of [`PDF`] instances, one for each member in the set. +fn pdfsets_par_loader(set: T) -> Vec { + (0..set.num_members()) + .into_par_iter() + .map(|idx| { + let (info, knot_array) = set.member(idx); + PDF { + grid_pdf: GridPDF::new(info, knot_array), + } + }) + .collect() +} + +/// Represents a Parton Distribution Function (PDF) set. +/// +/// This struct provides a high-level interface for accessing PDF data, +/// including interpolation and metadata retrieval. It encapsulates the +/// `GridPDF` struct, which handles the low-level grid operations. +pub struct PDF { + grid_pdf: GridPDF, +} + +impl PDF { + /// Loads a given member of the PDF set. + /// + /// This function reads the `.info` file and the corresponding `.dat` member file + /// to construct a `GridPDF` object, which is then wrapped in a `PDF` instance. + /// + /// # Arguments + /// + /// * `pdf_name` - The name of the PDF set (e.g., "NNPDF40_nnlo_as_01180"). + /// * `member` - The ID of the PDF member to load (0-indexed). + /// + /// # Returns + /// + /// A `PDF` instance representing the loaded PDF member. + pub fn load(pdf_name: &str, member: usize) -> Self { + if pdf_name.ends_with(".neopdf.lz4") { + pdfset_loader(NeopdfSet::new(pdf_name), member) + } else { + pdfset_loader(LhapdfSet::new(pdf_name), member) + } + } + + /// Loads all members of a PDF set in parallel. 
+ /// + /// This function reads the `.info` file and all `.dat` member files + /// to construct a `Vec`, with each `PDF` instance representing a member + /// of the set. The loading is performed in parallel. + /// + /// # Arguments + /// + /// * `pdf_name` - The name of the PDF set. + /// + /// # Returns + /// + /// A `Vec` where each element is a `PDF` instance for a member of the set. + pub fn load_pdfs(pdf_name: &str) -> Vec { + if pdf_name.ends_with(".neopdf.lz4") { + pdfsets_par_loader(NeopdfSet::new(pdf_name)) + } else { + pdfsets_par_loader(LhapdfSet::new(pdf_name)) + } + } + + /// Loads all members of a PDF set in sequential. + /// + /// This function reads the `.info` file and all `.dat` member files + /// to construct a `Vec`, with each `PDF` instance representing a member + /// of the set. The loading is performed in parallel. + /// + /// # Arguments + /// + /// * `pdf_name` - The name of the PDF set. + /// + /// # Returns + /// + /// A `Vec` where each element is a `PDF` instance for a member of the set. + pub fn load_pdfs_seq(pdf_name: &str) -> Vec { + if pdf_name.ends_with(".neopdf.lz4") { + pdfsets_seq_loader(NeopdfSet::new(pdf_name)) + } else { + pdfsets_seq_loader(LhapdfSet::new(pdf_name)) + } + } + + /// Creates an iterator that loads PDF members lazily. + /// + /// This function is suitable for `.neopdf.lz4` files, which support lazy loading. + /// It returns an iterator that yields `PDF` instances on demand, which is useful + /// for reducing memory consumption when working with large PDF sets. + /// + /// # Arguments + /// + /// * `pdf_name` - The name of the PDF set (must end with `.neopdf.lz4`). + /// + /// # Returns + /// + /// An iterator over `Result>`. 
+ pub fn load_pdfs_lazy( + pdf_name: &str, + ) -> impl Iterator>> { + assert!( + pdf_name.ends_with(".neopdf.lz4"), + "Lazy loading is only supported for .neopdf.lz4 files" + ); + + let iter_lazy = NeopdfSet::new(pdf_name).into_lazy_iterators(); + + iter_lazy.map(|grid_array_with_metadata_result| { + grid_array_with_metadata_result.map(|grid_array_with_metadata| { + let info = (*grid_array_with_metadata.metadata).clone(); + let knot_array = grid_array_with_metadata.grid; + PDF { + grid_pdf: GridPDF::new(info, knot_array), + } + }) + }) + } + + /// Clip the negative values for the `PDF` object. + /// + /// # Arguments + /// + /// * `option` - The method used to clip negative values. + pub fn set_force_positive(&mut self, option: ForcePositive) { + self.grid_pdf.set_force_positive(option); + } + + /// Clip the negative values for all the `PDF` objects. + /// + /// # Arguments + /// + /// * `pdfs` - A `Vec` where each element is a `PDF` instance. + /// * `option` - The method used to clip negative values. + pub fn set_force_positive_members(pdfs: &mut [PDF], option: ForcePositive) { + for pdf in pdfs { + pdf.set_force_positive(option.clone()); + } + } + + /// Returns the clipping method used for a single `PDF` object. + /// + /// # Returns + /// + /// The clipping method given as a `ForcePositive` object. + pub fn is_force_positive(&self) -> &ForcePositive { + self.grid_pdf + .force_positive + .as_ref() + .unwrap_or(&ForcePositive::NoClipping) + } + + /// Interpolates the PDF value (xf) for a given nucleon, alphas, flavor, x, and Q2. + /// + /// Abstraction to the `GridPDF::xfxq2` method. + /// + /// # Arguments + /// + /// * `id` - The flavor ID (PDG ID). + /// * `points` - A slice containing the collection of points to interpolate on. + /// + /// # Returns + /// + /// The interpolated PDF value `xf(nuclone, alphas, flavor, x, Q^2)`. 
+ pub fn xfxq2(&self, pid: i32, points: &[f64]) -> f64 { + self.grid_pdf.xfxq2(pid, points).unwrap() + } + + /// Interpolates the PDF value (xf) for multiple nucleons, alphas, flavors, xs, and Q2s. + /// + /// Abstraction to the `GridPDF::xfxq2s` method. + /// + /// # Arguments + /// + /// * `ids` - A vector of flavor IDs. + /// * `slice_points` - A slice containing the collection of knots to interpolate on. + /// A knot is a collection of points containing `(nucleon, alphas, x, Q2)`. + /// + /// # Returns + /// + /// A 2D array of interpolated PDF values with shape `[flavors, N_knots]`. + pub fn xfxq2s(&self, pids: Vec, slice_points: &[&[f64]]) -> Array2 { + self.grid_pdf.xfxq2s(pids, slice_points) + } + + /// Interpolates the PDF value (xf) for multiple points using Chebyshev batch interpolation. + /// + /// Abstraction to the `GridPDF::xfxq2_cheby_batch` method. + /// + /// # Arguments + /// + /// * `pid` - The flavor ID. + /// * `points` - A slice containing the collection of knots to interpolate on. + /// A knot is a collection of points containing `(nucleon, alphas, x, Q2)`. + /// + /// # Returns + /// + /// A `Vec` of interpolated PDF values. + pub fn xfxq2_cheby_batch(&self, pid: i32, points: &[&[f64]]) -> Vec { + self.grid_pdf.xfxq2_cheby_batch(pid, points).unwrap() + } + + /// Interpolates the strong coupling constant `alpha_s` for a given Q2. + /// + /// Abstraction to the `GridPDF::alphas_q2` method. + /// + /// # Arguments + /// + /// * `q2` - The squared energy scale. + /// + /// # Returns + /// + /// The interpolated `alpha_s` value. + pub fn alphas_q2(&self, q2: f64) -> f64 { + self.grid_pdf.alphas_q2(q2) + } + + /// Returns a reference to the PDF metadata. + /// + /// Abstraction to the `GridPDF::info` method. + /// + /// # Returns + /// + /// A `MetaData` struct containing information about the PDF set. + pub fn metadata(&self) -> &MetaData { + self.grid_pdf.metadata() + } + + /// Returns the number of subgrids in the PDF set. 
+ /// + /// # Returns + /// + /// The number of subgrids. + pub fn num_subgrids(&self) -> usize { + self.grid_pdf.knot_array.subgrids.len() + } + + /// Returns a reference to the subgrid at the given index. + /// + /// # Arguments + /// + /// * `index` - The index of the subgrid. + /// + /// # Returns + /// + /// A reference to the `SubGrid`. + pub fn subgrid(&self, index: usize) -> &SubGrid { + &self.grid_pdf.knot_array.subgrids[index] + } + + /// Returns references to all the subgrid at the given index. + /// + /// # Returns + /// + /// A reference to all the `SubGrid`. + pub fn subgrids(&self) -> &Vec { + &self.grid_pdf.knot_array.subgrids + } + + /// Returns the flavor PIDS of the PDG Grid. + /// + /// # Returns + /// + /// PID representation of the PDF. + pub fn pids(&self) -> &Array1 { + &self.grid_pdf.knot_array.pids + } + + /// Retrieves the ranges for the parameters. + /// + /// Abstraction to the `GridPDF::param_ranges` method. + /// + /// # Returns + /// + /// The minimum and maximum values for the parameters (x, q2, ...). + pub fn param_ranges(&self) -> RangeParameters { + self.grid_pdf.param_ranges() + } + + /// Retrieves the PDF value (xf) at a specific knot point in the grid. + /// + /// Abstraction to the `GridArray::xf_from_index` method. This method does not + /// perform any interpolation. + /// + /// # Arguments + /// + /// * `i_nucleons` - The index of the nucleon. + /// * `i_alphas` - The index of the alpha_s value. + /// * `i_kt` - The index of the `kT` value. + /// * `ix` - The index of the x-value. + /// * `iq2` - The index of the Q2-value. + /// * `id` - The flavor ID. + /// * `subgrid_id` - The ID of the subgrid. + /// + /// # Returns + /// + /// The PDF value at the specified knot. 
+ #[allow(clippy::too_many_arguments)] + pub fn xf_from_index( + &self, + i_nucleons: usize, + i_alphas: usize, + i_kt: usize, + ix: usize, + iq2: usize, + id: i32, + subgrid_id: usize, + ) -> f64 { + self.grid_pdf + .knot_array + .xf_from_index(i_nucleons, i_alphas, i_kt, ix, iq2, id, subgrid_id) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::gridpdf::{GridArray, GridPDF}; + use crate::metadata::{InterpolatorType, MetaData, MetaDataV1, SetType}; + use crate::parser::SubgridData; + + fn mock_metadata() -> MetaData { + MetaData::new_v1(MetaDataV1 { + set_desc: "Test".to_string(), + set_index: 7, + num_members: 1, + x_min: 0.1, + x_max: 0.2, + q_min: 1.0, + q_max: 2.0, + flavors: vec![21], + format: "test".to_string(), + alphas_q_values: vec![], + alphas_vals: vec![], + polarised: false, + set_type: SetType::SpaceLike, + interpolator_type: InterpolatorType::Bilinear, + error_type: "".to_string(), + hadron_pid: 2212, + git_version: "".to_string(), + code_version: "".to_string(), + flavor_scheme: "variable".to_string(), + order_qcd: 2, + alphas_order_qcd: 2, + m_w: 80.4, + m_z: 91.2, + m_up: 0.0, + m_down: 0.0, + m_strange: 0.0, + m_charm: 1.4, + m_bottom: 4.75, + m_top: 173.0, + alphas_type: "analytic".to_string(), + number_flavors: 5, + }) + } + + fn mock_pdf() -> PDF { + let data = vec![SubgridData { + nucleons: vec![1.0], + alphas: vec![0.118], + kts: vec![0.0], + xs: vec![0.1, 0.2], + q2s: vec![1.0, 2.0], + grid_data: vec![1.0, 2.0, 3.0, 4.0], + }]; + PDF { + grid_pdf: GridPDF::new(mock_metadata(), GridArray::new(data, vec![21])), + } + } + + #[test] + fn test_xfxq2() { + let pdf = mock_pdf(); + let val = pdf.xfxq2(21, &[0.15, 1.5]); + assert!((val - 2.5).abs() < 1e-10); + } + + #[test] + fn test_xfxq2s() { + let pdf = mock_pdf(); + let result = pdf.xfxq2s(vec![21], &[&[0.15, 1.5]]); + assert_eq!(result.shape(), &[1, 1]); + } + + #[test] + fn test_alphas_q2() { + let pdf = mock_pdf(); + assert!(pdf.alphas_q2(100.0) > 0.0); + } + + #[test] + fn 
test_metadata() { + let pdf = mock_pdf(); + assert_eq!(pdf.metadata().set_index, 7); + } + + #[test] + fn test_num_subgrids() { + let pdf = mock_pdf(); + assert_eq!(pdf.num_subgrids(), 1); + } + + #[test] + fn test_subgrid_access() { + let pdf = mock_pdf(); + assert_eq!(pdf.subgrid(0).xs.len(), 2); + assert_eq!(pdf.subgrids().len(), 1); + } + + #[test] + fn test_pids() { + let pdf = mock_pdf(); + assert_eq!(pdf.pids().len(), 1); + assert_eq!(pdf.pids()[0], 21); + } + + #[test] + fn test_param_ranges() { + let pdf = mock_pdf(); + let r = pdf.param_ranges(); + assert_eq!(r.x.min, 0.1); + assert_eq!(r.q2.max, 2.0); + } + + #[test] + fn test_force_positive() { + let mut pdf = mock_pdf(); + assert!(matches!(pdf.is_force_positive(), ForcePositive::NoClipping)); + pdf.set_force_positive(ForcePositive::ClipNegative); + assert!(matches!( + pdf.is_force_positive(), + ForcePositive::ClipNegative + )); + } + + #[test] + fn test_set_force_positive_members() { + let mut pdfs = vec![mock_pdf(), mock_pdf()]; + PDF::set_force_positive_members(&mut pdfs, ForcePositive::ClipSmall); + for pdf in &pdfs { + assert!(matches!(pdf.is_force_positive(), ForcePositive::ClipSmall)); + } + } + + #[test] + fn test_xf_from_index() { + let pdf = mock_pdf(); + let val = pdf.xf_from_index(0, 0, 0, 0, 0, 21, 0); + assert_eq!(val, 1.0); + } +} diff --git a/neopdf_legacy/src/strategy.rs b/neopdf_legacy/src/strategy.rs new file mode 100644 index 0000000..8c517ac --- /dev/null +++ b/neopdf_legacy/src/strategy.rs @@ -0,0 +1,1977 @@ +//! This module defines various interpolation strategies used within the `neopdf` library. +//! +//! It provides implementations for 1D, 2D, and 3D interpolation, including: +//! - `BilinearInterpolation`: Standard bilinear interpolation for 2D data. +//! - `LogBilinearInterpolation`: Bilinear interpolation performed in logarithmic space for both +//! coordinates, suitable for data that exhibits linear behavior in log-log plots. +//! 
- `LogBicubicInterpolation`: Bicubic interpolation with logarithmic coordinate scaling, +//! providing C1 continuity and higher accuracy for 2D data. +//! - `LogTricubicInterpolation`: Tricubic interpolation with logarithmic coordinate scaling, +//! extending bicubic interpolation to 3D data with C1 continuity. +//! - `AlphaSCubicInterpolation`: A specialized 1D cubic interpolation strategy for alpha_s values, +//! incorporating specific extrapolation rules as defined in LHAPDF. +//! +//! All interpolation strategies are designed to work with `ninterp`'s data structures and traits, +//! ensuring compatibility and extensibility. + +use ndarray::{Array2, Axis, Data, RawDataClone}; +use ninterp::data::{InterpData1D, InterpData2D, InterpData3D}; +use ninterp::error::{InterpolateError, ValidateError}; +use ninterp::strategy::traits::{Strategy1D, Strategy2D, Strategy3D}; +use serde::{Deserialize, Serialize}; +use std::f64::consts::PI; + +use super::utils; + +/// Implements bilinear interpolation for 2D data. +/// +/// This strategy performs linear interpolation sequentially along two dimensions. +/// It is suitable for smooth, continuous 2D datasets where a simple linear +/// approximation between grid points is sufficient. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct BilinearInterpolation; + +impl BilinearInterpolation { + /// Performs linear interpolation between two points. + /// + /// Given two points `(x1, y1)` and `(x2, y2)`, this function calculates the + /// y-value corresponding to a given `x` using linear interpolation. + /// + /// # Arguments + /// + /// * `x1` - The x-coordinate of the first point. + /// * `x2` - The x-coordinate of the second point. + /// * `y1` - The y-coordinate of the first point. + /// * `y2` - The y-coordinate of the second point. + /// * `x` - The x-coordinate at which to interpolate. + /// + /// # Returns + /// + /// The interpolated y-value. 
+ fn linear_interpolate(x1: f64, x2: f64, y1: f64, y2: f64, x: f64) -> f64 { + if x1 == x2 { + return y1; + } + y1 + (y2 - y1) * (x - x1) / (x2 - x1) + } +} + +impl Strategy2D for BilinearInterpolation +where + D: Data + RawDataClone + Clone, +{ + /// Performs bilinear interpolation at a given point. + /// + /// # Arguments + /// + /// * `data` - The interpolation data containing grid coordinates and values. + /// * `point` - A 2-element array `[x, y]` representing the coordinates to interpolate at. + /// + /// # Returns + /// + /// The interpolated value as a `Result`. + fn interpolate( + &self, + data: &InterpData2D, + point: &[f64; 2], + ) -> Result { + let [x, y] = *point; + + let x_coords = data.grid[0].as_slice().unwrap(); + let y_coords = data.grid[1].as_slice().unwrap(); + let values = &data.values; + + let x_idx = utils::find_interval_index(x_coords, x)?; + let y_idx = utils::find_interval_index(y_coords, y)?; + + let x1 = x_coords[x_idx]; + let x2 = x_coords[x_idx + 1]; + let y1 = y_coords[y_idx]; + let y2 = y_coords[y_idx + 1]; + + let q11 = values[[x_idx, y_idx]]; // f(x1, y1) + let q12 = values[[x_idx, y_idx + 1]]; // f(x1, y2) + let q21 = values[[x_idx + 1, y_idx]]; // f(x2, y1) + let q22 = values[[x_idx + 1, y_idx + 1]]; // f(x2, y2) + + let r1 = Self::linear_interpolate(x1, x2, q11, q21, x); + let r2 = Self::linear_interpolate(x1, x2, q12, q22, x); + + let result = Self::linear_interpolate(y1, y2, r1, r2, y); + + Ok(result) + } + + /// Indicates that this strategy does not allow extrapolation. + fn allow_extrapolate(&self) -> bool { + true + } +} + +/// Performs bilinear interpolation in log space. +/// +/// This strategy transforms the input coordinates to their natural logarithms +/// before performing bilinear interpolation, which is suitable for data +/// that is linear in log-log space. It is particularly useful for physical +/// quantities that span several orders of magnitude, such as momentum transfer +/// squared (Q²) or Bjorken x. 
+#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct LogBilinearInterpolation; + +impl Strategy2D for LogBilinearInterpolation +where + D: Data + RawDataClone + Clone, +{ + /// Initializes the strategy, performing validation checks. + /// + /// # Arguments + /// + /// * `data` - The interpolation data to validate. + fn init(&mut self, _data: &InterpData2D) -> Result<(), ValidateError> { + Ok(()) + } + + /// Performs log-bilinear interpolation at a given point. + /// + /// The input `point` coordinates are first transformed to log space, + /// then bilinear interpolation is applied. + /// + /// # Arguments + /// + /// * `data` - The interpolation data containing grid coordinates and values. + /// * `point` - A 2-element array `[x, y]` representing the coordinates to interpolate at. + /// + /// # Returns + /// + /// The interpolated value as a `Result`. + fn interpolate( + &self, + data: &InterpData2D, + point: &[f64; 2], + ) -> Result { + let [x, y] = *point; + + let x_coords = data.grid[0].as_slice().unwrap(); + let y_coords = data.grid[1].as_slice().unwrap(); + let values = &data.values; + + let x_idx = utils::find_interval_index(x_coords, x)?; + let y_idx = utils::find_interval_index(y_coords, y)?; + + let x1 = x_coords[x_idx]; + let x2 = x_coords[x_idx + 1]; + let y1 = y_coords[y_idx]; + let y2 = y_coords[y_idx + 1]; + + let q11 = values[[x_idx, y_idx]]; // f(x1, y1) + let q12 = values[[x_idx, y_idx + 1]]; // f(x1, y2) + let q21 = values[[x_idx + 1, y_idx]]; // f(x2, y1) + let q22 = values[[x_idx + 1, y_idx + 1]]; // f(x2, y2) + + let r1 = BilinearInterpolation::linear_interpolate(x1, x2, q11, q21, x); + let r2 = BilinearInterpolation::linear_interpolate(x1, x2, q12, q22, x); + + let result = BilinearInterpolation::linear_interpolate(y1, y2, r1, r2, y); + + Ok(result) + } + + /// Indicates that this strategy does not allow extrapolation. + fn allow_extrapolate(&self) -> bool { + true + } +} + +/// LogBicubic interpolation strategy for PDF-like data. 
+/// +/// This strategy implements bicubic interpolation with logarithmic coordinate scaling. +/// It is designed for interpolating Parton Distribution Functions (PDFs) where: +/// - x-coordinates (e.g., Bjorken x) are logarithmically spaced. +/// - y-coordinates (e.g., Q² values) are logarithmically spaced. +/// - z-values (PDF values) are interpolated using bicubic splines. +/// +/// Bicubic interpolation uses a 4x4 grid of points around the interpolation point +/// and provides C1 continuity (continuous first derivatives), resulting in a +/// smoother and more accurate interpolation compared to bilinear methods. +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct LogBicubicInterpolation { + coeffs: Vec, +} + +impl LogBicubicInterpolation { + /// Find the interval for bicubic interpolation. + /// + /// This function determines the appropriate interval index `i` within a set of + /// coordinates `coords` such that `coords[i] <= x < coords[i+1]`. For bicubic + /// interpolation, this index `i` is used to select the 4x4 grid of points + /// `[i-1, i, i+1, i+2]` that are relevant for the interpolation. + /// + /// # Arguments + /// + /// * `coords` - A slice of `f64` representing the sorted coordinate values. + /// * `x` - The `f64` value for which to find the interval. + /// + /// # Returns + /// + /// A `Result` containing the `usize` index of the lower bound of the interval + /// if successful, or an `InterpolateError` if `x` is out of bounds. 
+ fn find_bicubic_interval(coords: &[f64], x: f64) -> Result { + // Find the interval [i, i+1] such that coords[i] <= x < coords[i+1] + let i = utils::find_interval_index(coords, x)?; + Ok(i) + } + + /// Cubic interpolation using a passed array of coefficients (a*x^3 + b*x^2 + c*x + d) + pub fn hermite_cubic_interpolate_from_coeffs(t: f64, coeffs: &[f64; 4]) -> f64 { + let x = t; + let x2 = x * x; + let x3 = x2 * x; + coeffs[0] * x3 + coeffs[1] * x2 + coeffs[2] * x + coeffs[3] + } + + /// Calculates the derivative with respect to x at a given knot. + /// This mirrors the _ddx function in LHAPDF's C++ implementation. + pub fn calculate_ddx(data: &InterpData2D, ix: usize, iq2: usize) -> f64 + where + D: Data + RawDataClone + Clone, + { + let nxknots = data.grid[0].len(); + let x_coords = data.grid[0].as_slice().unwrap(); + let values = &data.values; + + let del1 = match ix { + 0 => 0.0, + i => x_coords[i] - x_coords[i - 1], + }; + + let del2 = match x_coords.get(ix + 1) { + Some(&next) => next - x_coords[ix], + None => 0.0, + }; + + if ix != 0 && ix != nxknots - 1 { + let lddx = (values[[ix, iq2]] - values[[ix - 1, iq2]]) / del1; + let rddx = (values[[ix + 1, iq2]] - values[[ix, iq2]]) / del2; + (lddx + rddx) / 2.0 + } else if ix == 0 { + (values[[ix + 1, iq2]] - values[[ix, iq2]]) / del2 + } else if ix == nxknots - 1 { + (values[[ix, iq2]] - values[[ix - 1, iq2]]) / del1 + } else { + panic!("Should not reach here: Invalid index for derivative calculation."); + } + } + + /// Computes the polynomial coefficients for bicubic interpolation, mirroring LHAPDF's C++ implementation. 
+ fn compute_polynomial_coefficients(data: &InterpData2D) -> Vec + where + D: Data + RawDataClone + Clone, + { + let nxknots = data.grid[0].len(); + let nq2knots = data.grid[1].len(); + let values = &data.values; + + // The shape of the coefficients array: (nxknots-1) * nq2knots * 4 (for a,b,c,d) + let mut coeffs: Vec = vec![0.0; (nxknots - 1) * nq2knots * 4]; + + for ix in 0..nxknots - 1 { + for iq2 in 0..nq2knots { + let dx = + data.grid[0].as_slice().unwrap()[ix + 1] - data.grid[0].as_slice().unwrap()[ix]; + + let vl = values[[ix, iq2]]; + let vh = values[[ix + 1, iq2]]; + let vdl = Self::calculate_ddx(data, ix, iq2) * dx; + let vdh = Self::calculate_ddx(data, ix + 1, iq2) * dx; + + // polynomial coefficients + let a = vdh + vdl - 2.0 * vh + 2.0 * vl; + let b = 3.0 * vh - 3.0 * vl - 2.0 * vdl - vdh; + let c = vdl; + let d = vl; + + let base_idx = (ix * nq2knots + iq2) * 4; + coeffs[base_idx] = a; + coeffs[base_idx + 1] = b; + coeffs[base_idx + 2] = c; + coeffs[base_idx + 3] = d; + } + } + coeffs + } + + /// Performs bicubic interpolation using pre-computed coefficients. 
+ fn interpolate_with_coeffs( + &self, + data: &InterpData2D, + ix: usize, + iq2: usize, + u: f64, + v: f64, + ) -> f64 + where + D: Data + RawDataClone + Clone, + { + let nq2knots = data.grid[1].len(); + + let base_idx_vl = (ix * nq2knots + iq2) * 4; + let coeffs_vl: [f64; 4] = self.coeffs[base_idx_vl..base_idx_vl + 4] + .try_into() + .unwrap(); + let vl = Self::hermite_cubic_interpolate_from_coeffs(u, &coeffs_vl); + + let base_idx_vh = (ix * nq2knots + iq2 + 1) * 4; + let coeffs_vh: [f64; 4] = self.coeffs[base_idx_vh..base_idx_vh + 4] + .try_into() + .unwrap(); + let vh = Self::hermite_cubic_interpolate_from_coeffs(u, &coeffs_vh); + + let q2_grid: &[f64] = data.grid[1].as_slice().unwrap(); + + let dq_1 = q2_grid[iq2 + 1] - q2_grid[iq2]; + + let vdl: f64; + let vdh: f64; + + if iq2 == 0 { + vdl = vh - vl; + let vhh_base_idx = (ix * nq2knots + iq2 + 2) * 4; + let coeffs_vhh: [f64; 4] = self.coeffs[vhh_base_idx..vhh_base_idx + 4] + .try_into() + .unwrap(); + let vhh = Self::hermite_cubic_interpolate_from_coeffs(u, &coeffs_vhh); + let dq_2 = 1.0 / (q2_grid[iq2 + 2] - q2_grid[iq2 + 1]); + vdh = (vdl + (vhh - vh) * dq_1 * dq_2) * 0.5; + } else if iq2 == nq2knots - 2 { + vdh = vh - vl; + let vll_base_idx = (ix * nq2knots + iq2 - 1) * 4; + let coeffs_vll: [f64; 4] = self.coeffs[vll_base_idx..vll_base_idx + 4] + .try_into() + .unwrap(); + let vll = Self::hermite_cubic_interpolate_from_coeffs(u, &coeffs_vll); + let dq_0 = 1.0 / (q2_grid[iq2] - q2_grid[iq2 - 1]); + vdl = (vdh + (vl - vll) * dq_1 * dq_0) * 0.5; + } else { + let vll_base_idx = (ix * nq2knots + iq2 - 1) * 4; + let coeffs_vll: [f64; 4] = self.coeffs[vll_base_idx..vll_base_idx + 4] + .try_into() + .unwrap(); + let vll = Self::hermite_cubic_interpolate_from_coeffs(u, &coeffs_vll); + let dq_0 = 1.0 / (q2_grid[iq2] - q2_grid[iq2 - 1]); + + let vhh_base_idx = (ix * nq2knots + iq2 + 2) * 4; + let coeffs_vhh: [f64; 4] = self.coeffs[vhh_base_idx..vhh_base_idx + 4] + .try_into() + .unwrap(); + let vhh = 
Self::hermite_cubic_interpolate_from_coeffs(u, &coeffs_vhh); + let dq_2 = 1.0 / (q2_grid[iq2 + 2] - q2_grid[iq2 + 1]); + + vdl = ((vh - vl) + (vl - vll) * dq_1 * dq_0) * 0.5; + vdh = ((vh - vl) + (vhh - vh) * dq_1 * dq_2) * 0.5; + } + + utils::hermite_cubic_interpolate(v, vl, vdl, vh, vdh) + } +} + +impl Strategy2D for LogBicubicInterpolation +where + D: Data + RawDataClone + Clone, +{ + fn init(&mut self, data: &InterpData2D) -> Result<(), ValidateError> { + let x_coords = data.grid[0].as_slice().unwrap(); + let y_coords = data.grid[1].as_slice().unwrap(); + + if x_coords.len() < 4 || y_coords.len() < 4 { + return Err(ValidateError::Other( + "Need at least 4x4 grid for bicubic interpolation".to_string(), + )); + } + + self.coeffs = Self::compute_polynomial_coefficients(data); + Ok(()) + } + + fn interpolate( + &self, + data: &InterpData2D, + point: &[f64; 2], + ) -> Result { + let [x, y] = *point; + + let x_coords = data.grid[0].as_slice().unwrap(); + let y_coords = data.grid[1].as_slice().unwrap(); + + let i = Self::find_bicubic_interval(x_coords, x)?; + let j = Self::find_bicubic_interval(y_coords, y)?; + + let dx = x_coords[i + 1] - x_coords[i]; + let dy = y_coords[j + 1] - y_coords[j]; + + if dx == 0.0 || dy == 0.0 { + return Err(InterpolateError::Other("Grid spacing is zero".to_string())); + } + + let u = (x - x_coords[i]) / dx; + let v = (y - y_coords[j]) / dy; + + let result = self.interpolate_with_coeffs(data, i, j, u, v); + + Ok(result) + } + + fn allow_extrapolate(&self) -> bool { + true + } +} + +/// LogTricubic interpolation strategy for PDF-like data +/// +/// This strategy implements tricubic interpolation with logarithmic coordinate scaling: +/// - x-coordinates are logarithmically spaced (e.g., 1e-9 to 1) +/// - y-coordinates are logarithmically spaced (e.g., Q² values) +/// - z-coordinates are logarithmically spaced (e.g., Mass Atomic A, AlphaS) +/// - w-values (PDF values) are interpolated using tricubic splines +/// +/// Tricubic interpolation 
uses a 4x4x4 grid of points around the interpolation point +/// and provides C1 continuity (continuous first derivatives). +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct LogTricubicInterpolation; + +impl LogTricubicInterpolation { + /// Returns the index i such that we can use points [i-1, i, i+1, i+2] for interpolation. + fn find_tricubic_interval(coords: &[f64], x: f64) -> Result { + // Find the interval [i, i+1] such that coords[i] <= x < coords[i+1] + let i = utils::find_interval_index(coords, x)?; + Ok(i) + } + + /// Cubic interpolation using a passed array of coefficients (a*x^3 + b*x^2 + c*x + d) + pub fn hermite_cubic_interpolate_from_coeffs(t: f64, coeffs: &[f64; 4]) -> f64 { + let x = t; + let x2 = x * x; + let x3 = x2 * x; + coeffs[0] * x3 + coeffs[1] * x2 + coeffs[2] * x + coeffs[3] + } + + /// Calculates the derivative with respect to x at a given knot. + pub fn calculate_ddx(data: &InterpData3D, ix: usize, iq2: usize, iz: usize) -> f64 + where + D: Data + RawDataClone + Clone, + { + let nxknots = data.grid[0].len(); + let x_coords = data.grid[0].as_slice().unwrap(); + let values = &data.values; + + let del1 = match ix { + 0 => 0.0, + i => x_coords[i] - x_coords[i - 1], + }; + + let del2 = match x_coords.get(ix + 1) { + Some(&next) => next - x_coords[ix], + None => 0.0, + }; + + if ix != 0 && ix != nxknots - 1 { + let lddx = (values[[ix, iq2, iz]] - values[[ix - 1, iq2, iz]]) / del1; + let rddx = (values[[ix + 1, iq2, iz]] - values[[ix, iq2, iz]]) / del2; + (lddx + rddx) / 2.0 + } else if ix == 0 { + (values[[ix + 1, iq2, iz]] - values[[ix, iq2, iz]]) / del2 + } else if ix == nxknots - 1 { + (values[[ix, iq2, iz]] - values[[ix - 1, iq2, iz]]) / del1 + } else { + panic!("Should not reach here: Invalid index for derivative calculation."); + } + } + + /// Calculates the derivative with respect to y at a given knot. 
+ pub fn calculate_ddy(data: &InterpData3D, ix: usize, iq2: usize, iz: usize) -> f64 + where + D: Data + RawDataClone + Clone, + { + let nq2knots = data.grid[1].len(); + let q2_coords = data.grid[1].as_slice().unwrap(); + let values = &data.values; + + let del1 = match iq2 { + 0 => 0.0, + i => q2_coords[i] - q2_coords[i - 1], + }; + + let del2 = match q2_coords.get(iq2 + 1) { + Some(&next) => next - q2_coords[iq2], + None => 0.0, + }; + + if iq2 != 0 && iq2 != nq2knots - 1 { + let lddq = (values[[ix, iq2, iz]] - values[[ix, iq2 - 1, iz]]) / del1; + let rddq = (values[[ix, iq2 + 1, iz]] - values[[ix, iq2, iz]]) / del2; + (lddq + rddq) / 2.0 + } else if iq2 == 0 { + (values[[ix, iq2 + 1, iz]] - values[[ix, iq2, iz]]) / del2 + } else if iq2 == nq2knots - 1 { + (values[[ix, iq2, iz]] - values[[ix, iq2 - 1, iz]]) / del1 + } else { + panic!("Should not reach here: Invalid index for derivative calculation."); + } + } + + /// Calculates the derivative with respect to z at a given knot. + pub fn calculate_ddz(data: &InterpData3D, ix: usize, iq2: usize, iz: usize) -> f64 + where + D: Data + RawDataClone + Clone, + { + let nmu2knots = data.grid[2].len(); + let mu2_coords = data.grid[2].as_slice().unwrap(); + let values = &data.values; + + let del1 = match iz { + 0 => 0.0, + i => mu2_coords[i] - mu2_coords[i - 1], + }; + + let del2 = match mu2_coords.get(iz + 1) { + Some(&next) => next - mu2_coords[iz], + None => 0.0, + }; + + if iz != 0 && iz != nmu2knots - 1 { + let lddmu = (values[[ix, iq2, iz]] - values[[ix, iq2, iz - 1]]) / del1; + let rddmu = (values[[ix, iq2, iz + 1]] - values[[ix, iq2, iz]]) / del2; + (lddmu + rddmu) / 2.0 + } else if iz == 0 { + (values[[ix, iq2, iz + 1]] - values[[ix, iq2, iz]]) / del2 + } else if iz == nmu2knots - 1 { + (values[[ix, iq2, iz]] - values[[ix, iq2, iz - 1]]) / del1 + } else { + panic!("Should not reach here: Invalid index for derivative calculation."); + } + } + + fn hermite_tricubic_interpolate( + &self, + data: &InterpData3D, + 
indices: (usize, usize, usize), + coords: (f64, f64, f64), + derivatives: (f64, f64, f64), + ) -> f64 + where + D: Data + RawDataClone + Clone, + { + let (ix, iq2, iz) = indices; + let (u, v, w) = coords; + let (dx, dy, dz) = derivatives; + + let get = |dx, dy, dz| data.values[[ix + dx, iq2 + dy, iz + dz]]; + let ddx = |dx, dy, dz| Self::calculate_ddx(data, ix + dx, iq2 + dy, iz + dz); + let ddy = |dx, dy, dz| Self::calculate_ddy(data, ix + dx, iq2 + dy, iz + dz); + let ddz = |dx, dy, dz| Self::calculate_ddz(data, ix + dx, iq2 + dy, iz + dz); + + let interp_y: [[f64; 2]; 4] = [0, 1] + .iter() + .flat_map(|&y_offset| { + [0, 1].iter().map(move |&z_offset| { + let (f0, f1) = (get(0, y_offset, z_offset), get(1, y_offset, z_offset)); + let (d0, d1) = ( + ddx(0, y_offset, z_offset) * dx, + ddx(1, y_offset, z_offset) * dx, + ); + let interp_val = Self::cubic_interpolate(u, f0, d0, f1, d1); + + let (df0, df1) = ( + ddy(0, y_offset, z_offset) * dy, + ddy(1, y_offset, z_offset) * dy, + ); + let interp_deriv = (1.0 - u) * df0 + u * df1; + + [interp_val, interp_deriv] + }) + }) + .collect::>() + .try_into() + .unwrap(); + + let interp_z: [[f64; 2]; 2] = [0, 1] + .iter() + .enumerate() + .map(|(iz_, &z_offset)| { + let (f0, f1) = (interp_y[iz_][0], interp_y[2 + iz_][0]); + let (d0, d1) = (interp_y[iz_][1], interp_y[2 + iz_][1]); + let interp_val = Self::cubic_interpolate(v, f0, d0, f1, d1); + + let calc_z_deriv = |y_offset| { + let (df0, df1) = ( + ddz(0, y_offset, z_offset) * dz, + ddz(1, y_offset, z_offset) * dz, + ); + (1.0 - u) * df0 + u * df1 + }; + + let interp_deriv = (1.0 - v) * calc_z_deriv(0) + v * calc_z_deriv(1); + [interp_val, interp_deriv] + }) + .collect::>() + .try_into() + .unwrap(); + + let (f0, f1) = (interp_z[0][0], interp_z[1][0]); + let (d0, d1) = (interp_z[0][1], interp_z[1][1]); + Self::cubic_interpolate(w, f0, d0, f1, d1) + } + + /// Hermite cubic interpolation with derivatives + fn cubic_interpolate(t: f64, f0: f64, f0_prime: f64, f1: f64, f1_prime: 
f64) -> f64 { + let t2 = t * t; + let t3 = t2 * t; + + // Hermite basis functions + let h00 = 2.0 * t3 - 3.0 * t2 + 1.0; + let h10 = t3 - 2.0 * t2 + t; + let h01 = -2.0 * t3 + 3.0 * t2; + let h11 = t3 - t2; + + h00 * f0 + h10 * f0_prime + h01 * f1 + h11 * f1_prime + } +} + +impl Strategy3D for LogTricubicInterpolation +where + D: Data + RawDataClone + Clone, +{ + fn init(&mut self, data: &InterpData3D) -> Result<(), ValidateError> { + let x_coords = data.grid[0].as_slice().unwrap(); + let y_coords = data.grid[1].as_slice().unwrap(); + let z_coords = data.grid[2].as_slice().unwrap(); + + if x_coords.len() < 4 || y_coords.len() < 4 || z_coords.len() < 4 { + return Err(ValidateError::Other( + "Need at least 4x4x4 grid for tricubic interpolation".to_string(), + )); + } + + // Uses the Hermite approach instead of coefficient precomputation. + // This is more straightforward and avoids the complex 64x64 matrix. + Ok(()) + } + + fn interpolate( + &self, + data: &InterpData3D, + point: &[f64; 3], + ) -> Result { + let [x, y, z] = *point; + + let x_coords = data.grid[0].as_slice().unwrap(); + let y_coords = data.grid[1].as_slice().unwrap(); + let z_coords = data.grid[2].as_slice().unwrap(); + + let i = Self::find_tricubic_interval(x_coords, x)?; + let j = Self::find_tricubic_interval(y_coords, y)?; + let k = Self::find_tricubic_interval(z_coords, z)?; + + let dx = x_coords[i + 1] - x_coords[i]; + let dy = y_coords[j + 1] - y_coords[j]; + let dz = z_coords[k + 1] - z_coords[k]; + + if dx == 0.0 || dy == 0.0 || dz == 0.0 { + return Err(InterpolateError::Other("Grid spacing is zero".to_string())); + } + + let u = (x - x_coords[i]) / dx; + let v = (y - y_coords[j]) / dy; + let w = (z - z_coords[k]) / dz; + + let result = self.hermite_tricubic_interpolate(data, (i, j, k), (u, v, w), (dx, dy, dz)); + + Ok(result) + } + + fn allow_extrapolate(&self) -> bool { + true + } +} + +/// Implements cubic interpolation for alpha_s values in log-Q2 space. 
+/// +/// This strategy handles the specific extrapolation and interpolation rules +/// for alpha_s as defined in LHAPDF. +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct AlphaSCubicInterpolation; + +impl AlphaSCubicInterpolation { + /// Get the index of the closest Q2 knot row <= q2 + /// + /// If the value is >= q2_max, return (i_max-1). + fn ilogq2below(data: &InterpData1D, logq2: f64) -> usize + where + D: Data + RawDataClone + Clone, + { + let logq2s = data.grid[0].as_slice().unwrap(); + if logq2 < *logq2s.first().unwrap() { + panic!( + "Q2 value {} is lower than lowest-Q2 grid point at {}", + logq2.exp(), + logq2s.first().unwrap().exp() + ); + } + if logq2 > *logq2s.last().unwrap() { + panic!( + "Q2 value {} is higher than highest-Q2 grid point at {}", + logq2.exp(), + logq2s.last().unwrap().exp() + ); + } + + let idx = logq2s.partition_point(|&x| x < logq2); + + if idx == logq2s.len() { + idx - 1 + } else if (logq2s[idx] - logq2).abs() < 1e-9 { + if idx == logq2s.len() - 1 && logq2s.len() >= 2 { + idx - 1 + } else { + idx + } + } else { + idx - 1 + } + } + + /// Forward derivative w.r.t. logQ2 + fn ddlogq_forward(data: &InterpData1D, i: usize) -> f64 + where + D: Data + RawDataClone + Clone, + { + let logq2s = data.grid[0].as_slice().unwrap(); + let alphas = data.values.as_slice().unwrap(); + (alphas[i + 1] - alphas[i]) / (logq2s[i + 1] - logq2s[i]) + } + + /// Backward derivative w.r.t. logQ2 + fn ddlogq_backward(data: &InterpData1D, i: usize) -> f64 + where + D: Data + RawDataClone + Clone, + { + let logq2s = data.grid[0].as_slice().unwrap(); + let alphas = data.values.as_slice().unwrap(); + (alphas[i] - alphas[i - 1]) / (logq2s[i] - logq2s[i - 1]) + } + + /// Central (avg of forward and backward) derivative w.r.t. 
logQ2 + fn ddlogq_central(data: &InterpData1D, i: usize) -> f64 + where + D: Data + RawDataClone + Clone, + { + 0.5 * (Self::ddlogq_forward(data, i) + Self::ddlogq_backward(data, i)) + } +} + +impl Strategy1D for AlphaSCubicInterpolation +where + D: Data + RawDataClone + Clone, +{ + fn interpolate( + &self, + data: &InterpData1D, + point: &[f64; 1], + ) -> Result { + let logq2 = point[0]; + let logq2s = data.grid[0].as_slice().unwrap(); + let alphas = data.values.as_slice().unwrap(); + + if logq2 < *logq2s.first().unwrap() { + let mut next_point = 1; + while logq2s[0] == logq2s[next_point] { + next_point += 1; + } + let dlogq2 = logq2s[next_point] - logq2s[0]; + let dlogas = (alphas[next_point] / alphas[0]).ln(); + let loggrad = dlogas / dlogq2; + return Ok(alphas[0] * (loggrad * (logq2 - logq2s[0])).exp()); + } + + if logq2 > *logq2s.last().unwrap() { + return Ok(*alphas.last().unwrap()); + } + + let i = Self::ilogq2below(data, logq2); + + // Calculate derivatives + let didlogq2: f64; + let di1dlogq2: f64; + if i == 0 { + didlogq2 = Self::ddlogq_forward(data, i); + di1dlogq2 = Self::ddlogq_central(data, i + 1); + } else if i == logq2s.len() - 2 { + didlogq2 = Self::ddlogq_central(data, i); + di1dlogq2 = Self::ddlogq_backward(data, i + 1); + } else { + didlogq2 = Self::ddlogq_central(data, i); + di1dlogq2 = Self::ddlogq_central(data, i + 1); + } + + // Calculate alpha_s + let dlogq2 = logq2s[i + 1] - logq2s[i]; + let tlogq2 = (logq2 - logq2s[i]) / dlogq2; + Ok(utils::hermite_cubic_interpolate( + tlogq2, + alphas[i], + didlogq2 * dlogq2, + alphas[i + 1], + di1dlogq2 * dlogq2, + )) + } + + fn allow_extrapolate(&self) -> bool { + true + } +} + +/// Implements a global N-dimensional interpolation using Chebyshev polynomials with logarithmic +/// coordinate scaling. 
+/// +/// This strategy, inspired by the method described in arXiv:2112.09703, first transforms the input +/// coordinates to their natural logarithms, and then fits a single, high-degree Chebyshev polynomial +/// to the entire dataset in the log-transformed space. +/// +/// Key features: +/// - **Logarithmic Scaling**: Coordinates are transformed via `x -> ln(x)` before interpolation. +/// - **Global Nature**: The interpolation at any point depends on all data points in the grid. +/// - **High Degree**: The degree of the interpolating polynomial is `N-1`, where `N` is the +/// number of grid points in each dimension. +/// - **Grid Requirement**: For optimal stability and to avoid Runge's phenomenon, the grid +/// points should correspond to the roots or extrema of Chebyshev polynomials. +#[derive(Debug, Clone)] +pub struct LogChebyshevInterpolation { + // Pre-computed weights for the barycentric formula for each dimension. + weights: [Vec; DIM], + // Grid points in the t-domain [-1, 1] for each dimension. 
+ t_coords: [Vec; DIM], +} + +impl Default for LogChebyshevInterpolation { + fn default() -> Self { + Self { + weights: std::array::from_fn(|_| Vec::new()), + t_coords: std::array::from_fn(|_| Vec::new()), + } + } +} + +impl Serialize for LogChebyshevInterpolation { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("LogChebyshevInterpolation", 2)?; + state.serialize_field("weights", &self.weights.as_slice())?; + state.serialize_field("t_coords", &self.t_coords.as_slice())?; + state.end() + } +} + +impl<'de, const DIM: usize> Deserialize<'de> for LogChebyshevInterpolation { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct Helper { + weights: Vec>, + t_coords: Vec>, + } + + let helper = Helper::deserialize(deserializer)?; + let weights = helper.weights.try_into().map_err(|v: Vec>| { + serde::de::Error::invalid_length(v.len(), &"an array of the correct length") + })?; + let t_coords = helper.t_coords.try_into().map_err(|v: Vec>| { + serde::de::Error::invalid_length(v.len(), &"an array of the correct length") + })?; + + Ok(Self { weights, t_coords }) + } +} + +impl LogChebyshevInterpolation { + /// Computes the barycentric weights for a given set of Chebyshev points. + /// The formula for the weights is `w_j = (-1)^j * delta_j`, where `delta_j` + /// is 1/2 for the first and last points, and 1 otherwise. 
+ fn compute_barycentric_weights(n: usize) -> Vec { + let mut weights = vec![1.0; n]; + (0..n).for_each(|j| { + if j % 2 == 1 { + weights[j] = -1.0; + } + }); + weights[0] *= 0.5; + if n > 1 { + weights[n - 1] *= 0.5; + } + weights + } + + /// Computes normalized barycentric coefficients for interpolation + /// Returns a vector of coefficients that sum to 1 + fn barycentric_coefficients(t: f64, t_coords: &[f64], weights: &[f64]) -> Vec { + let mut coeffs = vec![0.0; t_coords.len()]; + + for (j, &t_j) in t_coords.iter().enumerate() { + if (t - t_j).abs() < 1e-15 { + coeffs[j] = 1.0; + return coeffs; + } + } + + let mut terms = Vec::with_capacity(t_coords.len()); + for (j, &t_j) in t_coords.iter().enumerate() { + terms.push(weights[j] / (t - t_j)); + } + + let sum: f64 = terms.iter().sum(); + for (j, &term) in terms.iter().enumerate() { + coeffs[j] = term / sum; + } + + coeffs + } + + /// Legacy barycentric interpolation method (kept for compatibility) + fn barycentric_interpolate(t: f64, t_coords: &[f64], f_values: &[f64], weights: &[f64]) -> f64 { + let mut numer = 0.0; + let mut denom = 0.0; + + for (j, &t_j) in t_coords.iter().enumerate() { + if (t - t_j).abs() < 1e-15 { + return f_values[j]; + } + + let term = weights[j] / (t - t_j); + numer += term * f_values[j]; + denom += term; + } + + numer / denom + } +} + +impl Strategy1D for LogChebyshevInterpolation<1> +where + D: Data + RawDataClone + Clone, +{ + fn init(&mut self, data: &InterpData1D) -> Result<(), ValidateError> { + let x_coords = data.grid[0].as_slice().unwrap(); + let n = x_coords.len(); + if n < 2 { + return Err(ValidateError::Other( + "LogChebyshevInterpolation requires at least 2 grid points.".to_string(), + )); + } + + self.t_coords[0] = (0..n) + .map(|j| (PI * (n - 1 - j) as f64 / (n - 1) as f64).cos()) + .collect(); + + self.weights[0] = Self::compute_barycentric_weights(n); + + Ok(()) + } + + fn interpolate( + &self, + data: &InterpData1D, + point: &[f64; 1], + ) -> Result { + let x = 
point[0]; + let x_coords = data.grid[0].as_slice().unwrap(); + let f_values = data.values.as_slice().unwrap(); + + let x_min = *x_coords.first().unwrap(); + let x_max = *x_coords.last().unwrap(); + + if (x_max - x_min).abs() < 1e-15 { + return Ok(f_values[0]); + } + let t = 2.0 * (x - x_min) / (x_max - x_min) - 1.0; + + Ok(Self::barycentric_interpolate( + t, + &self.t_coords[0], + f_values, + &self.weights[0], + )) + } + + fn allow_extrapolate(&self) -> bool { + true + } +} + +impl Strategy2D for LogChebyshevInterpolation<2> +where + D: Data + RawDataClone + Clone, +{ + fn init(&mut self, data: &InterpData2D) -> Result<(), ValidateError> { + for dim in 0..2 { + let x_coords = data.grid[dim].as_slice().unwrap(); + let n = x_coords.len(); + if n < 2 { + return Err(ValidateError::Other( + "LogChebyshevInterpolation requires at least 2 grid points per dimension." + .to_string(), + )); + } + self.t_coords[dim] = (0..n) + .map(|j| (PI * (n - 1 - j) as f64 / (n - 1) as f64).cos()) + .collect(); + self.weights[dim] = Self::compute_barycentric_weights(n); + } + Ok(()) + } + + fn interpolate( + &self, + data: &InterpData2D, + point: &[f64; 2], + ) -> Result { + let [x, y] = *point; + let x_coords = data.grid[0].as_slice().unwrap(); + let y_coords = data.grid[1].as_slice().unwrap(); + + let x_min = *x_coords.first().unwrap(); + let x_max = *x_coords.last().unwrap(); + let y_min = *y_coords.first().unwrap(); + let y_max = *y_coords.last().unwrap(); + + let t_x = 2.0 * (x - x_min) / (x_max - x_min) - 1.0; + let t_y = 2.0 * (y - y_min) / (y_max - y_min) - 1.0; + + let x_coeffs = Self::barycentric_coefficients(t_x, &self.t_coords[0], &self.weights[0]); + let y_coeffs = Self::barycentric_coefficients(t_y, &self.t_coords[1], &self.weights[1]); + + let mut result = 0.0; + for (i, &x_coeff) in x_coeffs.iter().enumerate() { + for (j, &y_coeff) in y_coeffs.iter().enumerate() { + result += x_coeff * y_coeff * data.values[[i, j]]; + } + } + + Ok(result) + } + + fn 
allow_extrapolate(&self) -> bool { + true + } +} + +impl Strategy3D for LogChebyshevInterpolation<3> +where + D: Data + RawDataClone + Clone, +{ + fn init(&mut self, data: &InterpData3D) -> Result<(), ValidateError> { + for dim in 0..3 { + let x_coords = data.grid[dim].as_slice().unwrap(); + let n = x_coords.len(); + if n < 2 { + return Err(ValidateError::Other( + "LogChebyshevInterpolation requires at least 2 grid points per dimension." + .to_string(), + )); + } + self.t_coords[dim] = (0..n) + .map(|j| (PI * (n - 1 - j) as f64 / (n - 1) as f64).cos()) + .collect(); + self.weights[dim] = Self::compute_barycentric_weights(n); + } + Ok(()) + } + + fn interpolate( + &self, + data: &InterpData3D, + point: &[f64; 3], + ) -> Result { + let [x, y, z] = *point; + let x_coords = data.grid[0].as_slice().unwrap(); + let y_coords = data.grid[1].as_slice().unwrap(); + let z_coords = data.grid[2].as_slice().unwrap(); + + let x_min = *x_coords.first().unwrap(); + let x_max = *x_coords.last().unwrap(); + let y_min = *y_coords.first().unwrap(); + let y_max = *y_coords.last().unwrap(); + let z_min = *z_coords.first().unwrap(); + let z_max = *z_coords.last().unwrap(); + + let t_x = 2.0 * (x - x_min) / (x_max - x_min) - 1.0; + let t_y = 2.0 * (y - y_min) / (y_max - y_min) - 1.0; + let t_z = 2.0 * (z - z_min) / (z_max - z_min) - 1.0; + + let x_coeffs = Self::barycentric_coefficients(t_x, &self.t_coords[0], &self.weights[0]); + let y_coeffs = Self::barycentric_coefficients(t_y, &self.t_coords[1], &self.weights[1]); + let z_coeffs = Self::barycentric_coefficients(t_z, &self.t_coords[2], &self.weights[2]); + + let mut result = 0.0; + for (i, &x_coeff) in x_coeffs.iter().enumerate() { + for (j, &y_coeff) in y_coeffs.iter().enumerate() { + for (k, &z_coeff) in z_coeffs.iter().enumerate() { + result += x_coeff * y_coeff * z_coeff * data.values[[i, j, k]]; + } + } + } + + Ok(result) + } + + fn allow_extrapolate(&self) -> bool { + true + } +} + +/// Implements a global N-dimensional batch 
interpolation using Chebyshev polynomials +/// with logarithmic coordinate scaling. +/// +/// This strategy is optimized for interpolating multiple points at once by leveraging +/// matrix operations with `ndarray`. +/// +/// TODO: Potentially merge this with `LogChebyshevInterpolation`. +#[derive(Debug, Clone)] +pub struct LogChebyshevBatchInterpolation { + // Pre-computed weights for the barycentric formula for each dimension. + weights: [Vec; DIM], + // Grid points in the t-domain [-1, 1] for each dimension. + t_coords: [Vec; DIM], +} + +impl Default for LogChebyshevBatchInterpolation { + fn default() -> Self { + Self { + weights: std::array::from_fn(|_| Vec::new()), + t_coords: std::array::from_fn(|_| Vec::new()), + } + } +} + +impl Serialize for LogChebyshevBatchInterpolation { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("LogChebyshevBatchInterpolation", 2)?; + state.serialize_field("weights", &self.weights.as_slice())?; + state.serialize_field("t_coords", &self.t_coords.as_slice())?; + state.end() + } +} + +impl<'de, const DIM: usize> Deserialize<'de> for LogChebyshevBatchInterpolation { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct Helper { + weights: Vec>, + t_coords: Vec>, + } + + let helper = Helper::deserialize(deserializer)?; + let weights = helper.weights.try_into().map_err(|v: Vec>| { + serde::de::Error::invalid_length(v.len(), &"an array of the correct length") + })?; + let t_coords = helper.t_coords.try_into().map_err(|v: Vec>| { + serde::de::Error::invalid_length(v.len(), &"an array of the correct length") + })?; + + Ok(Self { weights, t_coords }) + } +} + +impl LogChebyshevBatchInterpolation { + /// Computes the barycentric weights for a given set of Chebyshev points. 
+ /// The formula for the weights is `w_j = (-1)^j * delta_j`, where `delta_j` + /// is 1/2 for the first and last points, and 1 otherwise. + fn compute_barycentric_weights(n: usize) -> Vec { + let mut weights = vec![1.0; n]; + (0..n).for_each(|j| { + if j % 2 == 1 { + weights[j] = -1.0; + } + }); + weights[0] *= 0.5; + + if n > 1 { + weights[n - 1] *= 0.5; + } + + weights + } + + /// Compute barycentric coefficients for multiple points in batch + fn barycentric_coefficients( + t_values: &[f64], + t_coords: &[f64], + weights: &[f64], + ) -> Array2 { + let num_points = t_values.len(); + let num_coords = t_coords.len(); + let mut coeffs = Array2::::zeros((num_points, num_coords)); + + for (p, &t) in t_values.iter().enumerate() { + let mut found_exact = false; + for (j, &t_j) in t_coords.iter().enumerate() { + if (t - t_j).abs() < 1e-15 { + coeffs[[p, j]] = 1.0; + found_exact = true; + break; + } + } + + if !found_exact { + let mut terms = Vec::with_capacity(num_coords); + for (j, &t_j) in t_coords.iter().enumerate() { + terms.push(weights[j] / (t - t_j)); + } + + let sum: f64 = terms.iter().sum(); + + for (j, &term) in terms.iter().enumerate() { + coeffs[[p, j]] = term / sum; + } + } + } + + coeffs + } +} + +impl LogChebyshevBatchInterpolation<1> { + pub fn init(&mut self, data: &InterpData1D) -> Result<(), ValidateError> + where + D: Data + RawDataClone + Clone, + { + let x_coords = data.grid[0].as_slice().unwrap(); + let n = x_coords.len(); + if n < 2 { + return Err(ValidateError::Other( + "LogChebyshevBatchInterpolation requires at least 2 grid points.".to_string(), + )); + } + + self.t_coords[0] = (0..n) + .map(|j| (PI * (n - 1 - j) as f64 / (n - 1) as f64).cos()) + .collect(); + + self.weights[0] = Self::compute_barycentric_weights(n); + + Ok(()) + } + + pub fn interpolate( + &self, + data: &InterpData1D, + points: &[[f64; 1]], + ) -> Result, InterpolateError> + where + D: Data + RawDataClone + Clone, + { + let x_coords = data.grid[0].as_slice().unwrap(); + let 
f_values = data.values.to_owned(); + + let x_min = *x_coords.first().unwrap(); + let x_max = *x_coords.last().unwrap(); + + let mut t_x_vals = Vec::with_capacity(points.len()); + for &[x] in points { + t_x_vals.push(2.0 * (x - x_min) / (x_max - x_min) - 1.0); + } + + let c_x = Self::barycentric_coefficients(&t_x_vals, &self.t_coords[0], &self.weights[0]); + let results = c_x.dot(&f_values); + + Ok(results.to_vec()) + } +} + +impl LogChebyshevBatchInterpolation<2> { + pub fn init(&mut self, data: &InterpData2D) -> Result<(), ValidateError> + where + D: Data + RawDataClone + Clone, + { + for dim in 0..2 { + let x_coords = data.grid[dim].as_slice().unwrap(); + let n = x_coords.len(); + if n < 2 { + return Err(ValidateError::Other( + "LogChebyshevBatchInterpolation requires at least 2 grid points per dimension." + .to_string(), + )); + } + self.t_coords[dim] = (0..n) + .map(|j| (PI * (n - 1 - j) as f64 / (n - 1) as f64).cos()) + .collect(); + self.weights[dim] = Self::compute_barycentric_weights(n); + } + + Ok(()) + } + + pub fn interpolate( + &self, + data: &InterpData2D, + points: &[[f64; 2]], + ) -> Result, InterpolateError> + where + D: Data + RawDataClone + Clone, + { + let x_coords = data.grid[0].as_slice().unwrap(); + let y_coords = data.grid[1].as_slice().unwrap(); + + let x_min = *x_coords.first().unwrap(); + let x_max = *x_coords.last().unwrap(); + let y_min = *y_coords.first().unwrap(); + let y_max = *y_coords.last().unwrap(); + + let mut t_x_vals = Vec::with_capacity(points.len()); + let mut t_y_vals = Vec::with_capacity(points.len()); + for &[x, y] in points { + t_x_vals.push(2.0 * (x - x_min) / (x_max - x_min) - 1.0); + t_y_vals.push(2.0 * (y - y_min) / (y_max - y_min) - 1.0); + } + + let c_x = Self::barycentric_coefficients(&t_x_vals, &self.t_coords[0], &self.weights[0]); + let c_y = Self::barycentric_coefficients(&t_y_vals, &self.t_coords[1], &self.weights[1]); + let v = data.values.to_owned(); + let results = (&c_x.dot(&v) * &c_y).sum_axis(Axis(1)); + 
+ Ok(results.to_vec()) + } +} + +impl LogChebyshevBatchInterpolation<3> { + pub fn init(&mut self, data: &InterpData3D) -> Result<(), ValidateError> + where + D: Data + RawDataClone + Clone, + { + for dim in 0..3 { + let x_coords = data.grid[dim].as_slice().unwrap(); + let n = x_coords.len(); + if n < 2 { + return Err(ValidateError::Other( + "LogChebyshevBatchInterpolation requires at least 2 grid points per dimension." + .to_string(), + )); + } + self.t_coords[dim] = (0..n) + .map(|j| (PI * (n - 1 - j) as f64 / (n - 1) as f64).cos()) + .collect(); + self.weights[dim] = Self::compute_barycentric_weights(n); + } + Ok(()) + } + + pub fn interpolate( + &self, + data: &InterpData3D, + points: &[[f64; 3]], + ) -> Result, InterpolateError> + where + D: Data + RawDataClone + Clone, + { + let x_coords = data.grid[0].as_slice().unwrap(); + let y_coords = data.grid[1].as_slice().unwrap(); + let z_coords = data.grid[2].as_slice().unwrap(); + + let x_min = *x_coords.first().unwrap(); + let x_max = *x_coords.last().unwrap(); + let y_min = *y_coords.first().unwrap(); + let y_max = *y_coords.last().unwrap(); + let z_min = *z_coords.first().unwrap(); + let z_max = *z_coords.last().unwrap(); + + let mut t_x_vals = Vec::with_capacity(points.len()); + let mut t_y_vals = Vec::with_capacity(points.len()); + let mut t_z_vals = Vec::with_capacity(points.len()); + + for &[x, y, z] in points { + t_x_vals.push(2.0 * (x - x_min) / (x_max - x_min) - 1.0); + t_y_vals.push(2.0 * (y - y_min) / (y_max - y_min) - 1.0); + t_z_vals.push(2.0 * (z - z_min) / (z_max - z_min) - 1.0); + } + + let c_x = Self::barycentric_coefficients(&t_x_vals, &self.t_coords[0], &self.weights[0]); + let c_y = Self::barycentric_coefficients(&t_y_vals, &self.t_coords[1], &self.weights[1]); + let c_z = Self::barycentric_coefficients(&t_z_vals, &self.t_coords[2], &self.weights[2]); + + let v = &data.values; + + let num_points = points.len(); + let (nx, ny, nz) = (x_coords.len(), y_coords.len(), z_coords.len()); + + let 
v_flat = v.to_owned().into_shape_with_order((nx, ny * nz)).unwrap(); + let temp1 = c_x.dot(&v_flat); + let temp1_3d = temp1.into_shape_with_order((num_points, ny, nz)).unwrap(); + + let mut results = Vec::with_capacity(num_points); + for p in 0..num_points { + let temp_slice = temp1_3d.index_axis(Axis(0), p); + let cy_slice = c_y.row(p); + let cz_slice = c_z.row(p); + + let temp2 = cy_slice.dot(&temp_slice); + let result = cz_slice.dot(&temp2); + results.push(result); + } + + Ok(results) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use itertools::Itertools; + use ndarray::{Array1, Array2, Array3, OwnedRepr}; + use ninterp::data::{InterpData1D, InterpData2D}; + use ninterp::interpolator::{Extrapolate, InterpND}; + use ninterp::prelude::Interpolator; + use ninterp::strategy::Linear; + use std::f64::consts::PI; + + const EPSILON: f64 = 1e-9; + + fn assert_close(actual: f64, expected: f64, tolerance: f64) { + assert!( + (actual - expected).abs() < tolerance, + "Expected {}, got {} (diff: {})", + expected, + actual, + (actual - expected).abs() + ); + } + + fn create_target_data_2d(max_num: i32) -> Vec { + (1..=max_num) + .flat_map(|i| (1..=max_num).map(move |j| (i * j) as f64)) + .collect() + } + + fn create_logspaced(start: f64, stop: f64, n: usize) -> Vec { + (0..n) + .map(|value| { + let t = value as f64 / (n - 1) as f64; + start * (stop / start).powf(t) + }) + .collect() + } + + fn create_test_data_1d( + q2_values: Vec, + alphas_vals: Vec, + ) -> InterpData1D> { + InterpData1D::new(Array1::from(q2_values), Array1::from(alphas_vals)).unwrap() + } + + fn create_test_data_2d( + x_coords: Vec, + y_coords: Vec, + values: Vec, + ) -> InterpData2D> { + let shape = (x_coords.len(), y_coords.len()); + let values_array = Array2::from_shape_vec(shape, values).unwrap(); + InterpData2D::new(x_coords.into(), y_coords.into(), values_array).unwrap() + } + + fn create_test_data_3d( + x_coords: Vec, + y_coords: Vec, + z_coords: Vec, + values: Vec, + ) -> InterpData3D> { + 
let shape = (x_coords.len(), y_coords.len(), z_coords.len()); + let values_array = Array3::from_shape_vec(shape, values).unwrap(); + InterpData3D::new( + x_coords.into(), + y_coords.into(), + z_coords.into(), + values_array, + ) + .unwrap() + } + + fn create_cheby_grid(n_points: i32, x_min: f64, x_max: f64) -> Vec { + let u_min = x_min.ln(); + let u_max = x_max.ln(); + (0..n_points) + .map(|j| { + let t_j = (PI * (n_points - 1 - j) as f64 / (n_points - 1) as f64).cos(); + let u_j = u_min + (u_max - u_min) * (t_j + 1.0) / 2.0; + u_j.exp() + }) + .collect::>() + } + + #[test] + fn test_linear_interpolate() { + let test_cases = [ + // (x1, x2, y1, y2, x, expected) + (0.0, 1.0, 0.0, 10.0, 0.5, 5.0), + (0.0, 10.0, 0.0, 100.0, 2.5, 25.0), + (0.0, 1.0, 0.0, 10.0, 0.0, 0.0), // At start endpoint + (0.0, 1.0, 0.0, 10.0, 1.0, 10.0), // At end endpoint + (5.0, 5.0, 10.0, 20.0, 5.0, 10.0), // x1 == x2 case + ]; + + for (x1, x2, y1, y2, x, expected) in test_cases { + let result = BilinearInterpolation::linear_interpolate(x1, x2, y1, y2, x); + assert_close(result, expected, EPSILON); + } + } + + #[test] + fn test_bilinear_interpolation() { + let data = create_test_data_2d( + vec![0.0, 1.0, 2.0], + vec![0.0, 1.0, 2.0], + vec![0.0, 1.0, 2.0, 1.0, 2.0, 3.0, 2.0, 3.0, 4.0], + ); + + let test_cases = [ + ([0.5, 0.5], 1.0), + ([1.0, 1.0], 2.0), // Grid point + ([0.25, 0.75], 1.0), + ]; + + for (point, expected) in test_cases { + let result = BilinearInterpolation.interpolate(&data, &point).unwrap(); + assert_close(result, expected, EPSILON); + } + } + + #[test] + fn test_log_bilinear_interpolation() { + let data = create_test_data_2d( + vec![1.0f64.ln(), 10.0f64.ln(), 100.0f64.ln()], + vec![1.0f64.ln(), 10.0f64.ln(), 100.0f64.ln()], + vec![0.0, 1.0, 2.0, 1.0, 2.0, 3.0, 2.0, 3.0, 4.0], + ); + LogBilinearInterpolation.init(&data).unwrap(); + + let test_cases = [ + ([3.16227766f64.ln(), 3.16227766f64.ln()], 1.0), // sqrt(10) + ([10.0f64.ln(), 10.0f64.ln()], 2.0), // Grid point + 
([1.77827941f64.ln(), 5.62341325f64.ln()], 1.0), // 10^0.25, 10^0.75 + ]; + + for (point, expected) in test_cases { + let result = LogBilinearInterpolation.interpolate(&data, &point).unwrap(); + assert_close(result, expected, EPSILON); + } + } + + #[test] + fn test_log_tricubic_interpolation() { + let x_coords = create_logspaced(1e-5, 1e-3, 6); + let y_coords = create_logspaced(1e2, 1e4, 6); + let z_coords = vec![1.0, 5.0, 25.0, 100.0, 150.0, 200.0]; + let values: Vec = x_coords + .iter() + .cartesian_product(y_coords.iter()) + .cartesian_product(z_coords.iter()) + .map(|((&a, &b), &c)| a * b * c) + .collect(); + + let values_ln: Vec = values.iter().map(|val| val.ln()).collect(); + let interp_data_ln = create_test_data_3d( + x_coords.iter().map(|v| v.ln()).collect(), + y_coords.iter().map(|v| v.ln()).collect(), + z_coords.iter().map(|v| v.ln()).collect(), + values_ln.clone(), + ); + + let mut strategy = LogTricubicInterpolation; + strategy.init(&interp_data_ln).unwrap(); + + let point: [f64; 3] = [1e-4, 2e3, 25.0]; + let log_point = [point[0].ln(), point[1].ln(), point[2].ln()]; + let expected: f64 = point.iter().product(); + let result = strategy + .interpolate(&interp_data_ln, &log_point) + .unwrap() + .exp(); + assert_close(result, expected, EPSILON); + + let interp_data_arr = + Array3::from_shape_vec((x_coords.len(), y_coords.len(), z_coords.len()), values) + .unwrap(); + let nd_interp = InterpND::new( + vec![x_coords.into(), y_coords.into(), z_coords.into()], + interp_data_arr.into_dyn(), + Linear, + Extrapolate::Error, + ) + .unwrap(); + let nd_interp_res = nd_interp.interpolate(&point).unwrap(); + assert_close(nd_interp_res, expected, EPSILON); + } + + #[test] + fn test_alphas_cubic_interpolation() { + let q_values = [1.0f64, 2.0, 3.0, 4.0, 5.0]; + let alphas_vals = vec![0.1, 0.11, 0.12, 0.13, 0.14]; + let logq2_values: Vec = q_values.iter().map(|&q| (q * q).ln()).collect(); + let data = create_test_data_1d(logq2_values, alphas_vals); + let alphas_cubic = 
AlphaSCubicInterpolation; + + // Test within interpolation range + let result = alphas_cubic.interpolate(&data, &[2.25f64.ln()]).unwrap(); + assert!(result > 0.1 && result < 0.14); + + // Test at grid point + let result = alphas_cubic.interpolate(&data, &[4.0f64.ln()]).unwrap(); + assert_close(result, 0.11, EPSILON); + + // Test extrapolation below range + let result = alphas_cubic.interpolate(&data, &[0.5f64.ln()]).unwrap(); + assert!(result < 0.1); + + // Test extrapolation above range + let result = alphas_cubic.interpolate(&data, &[30.0f64.ln()]).unwrap(); + assert_close(result, 0.14, EPSILON); + } + + #[test] + fn test_find_bicubic_interval() { + let coords = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + + let test_cases = [ + (1.5, Ok(0)), + (3.5, Ok(2)), + (2.0, Ok(1)), // At knot point + (1.0, Ok(0)), // At boundary + (4.99, Ok(3)), // Near boundary + (0.5, Err(())), // Out of bounds + (5.5, Err(())), // Out of bounds + ]; + + for (value, expected) in test_cases { + let result = LogBicubicInterpolation::find_bicubic_interval(&coords, value); + match expected { + Ok(expected_idx) => assert_eq!(result.unwrap(), expected_idx), + Err(_) => assert!(result.is_err()), + } + } + } + + #[test] + fn test_hermite_cubic_interpolate_from_coeffs() { + let test_cases = [ + // Linear function x: coeffs = [0, 0, 1, 0] + ([0.0, 0.0, 1.0, 0.0], 0.5, 0.5), + ([0.0, 0.0, 1.0, 0.0], 1.0, 1.0), + // Constant function 5: coeffs = [0, 0, 0, 5] + ([0.0, 0.0, 0.0, 5.0], 0.5, 5.0), + // Cubic function x^3: coeffs = [1, 0, 0, 0] + ([1.0, 0.0, 0.0, 0.0], 2.0, 8.0), + ([1.0, 0.0, 0.0, 0.0], 0.5, 0.125), + // Complex polynomial 2x^3 - 3x^2 + x + 4 + ([2.0, -3.0, 1.0, 4.0], 1.0, 4.0), + ([2.0, -3.0, 1.0, 4.0], 0.0, 4.0), + ([2.0, -3.0, 1.0, 4.0], 2.0, 10.0), + ]; + + for (coeffs, x, expected) in test_cases { + let result = LogBicubicInterpolation::hermite_cubic_interpolate_from_coeffs(x, &coeffs); + assert_close(result, expected, EPSILON); + } + } + + #[test] + fn test_log_bicubic_interpolation() { + 
let target_data = create_target_data_2d(4); + let data = create_test_data_2d( + vec![1.0f64.ln(), 10.0f64.ln(), 100.0f64.ln(), 1000.0f64.ln()], + vec![1.0f64.ln(), 10.0f64.ln(), 100.0f64.ln(), 1000.0f64.ln()], + target_data, + ); + + let mut log_bicubic = LogBicubicInterpolation::default(); + log_bicubic.init(&data).unwrap(); + + let test_cases = [ + ([10.0f64.ln(), 10.0f64.ln()], 4.0), // Grid point + ([3.16227766f64.ln(), 3.16227766f64.ln()], 2.25), // sqrt(10) + ([31.6227766f64.ln(), 31.6227766f64.ln()], 6.25), // 10^1.5 + ]; + + for (point, expected) in test_cases { + let result = log_bicubic.interpolate(&data, &point).unwrap(); + assert_close(result, expected, EPSILON); + } + } + + #[test] + fn test_ddlogq_derivatives() { + let data = create_test_data_1d( + vec![1.0f64.ln(), 2.0f64.ln(), 3.0f64.ln(), 4.0f64.ln()], + vec![0.1, 0.2, 0.3, 0.4], + ); + + let expected_forward = 0.1 / (2.0f64.ln() - 1.0f64.ln()); + assert_close( + AlphaSCubicInterpolation::ddlogq_forward(&data, 0), + expected_forward, + EPSILON, + ); + + let expected_backward = 0.1 / (2.0f64.ln() - 1.0f64.ln()); + assert_close( + AlphaSCubicInterpolation::ddlogq_backward(&data, 1), + expected_backward, + EPSILON, + ); + + let expected_central = + 0.5 * (0.1 / (3.0f64.ln() - 2.0f64.ln()) + 0.1 / (2.0f64.ln() - 1.0f64.ln())); + assert_close( + AlphaSCubicInterpolation::ddlogq_central(&data, 1), + expected_central, + EPSILON, + ); + } + + #[test] + fn test_ilogq2below() { + let data = create_test_data_1d( + vec![ + 1.0f64.ln(), + 2.0f64.ln(), + 3.0f64.ln(), + 4.0f64.ln(), + 5.0f64.ln(), + ], + vec![0.1, 0.2, 0.3, 0.4, 0.5], + ); + + let test_cases = [ + (1.5f64.ln(), 0), + (2.0f64.ln(), 1), + (3.9f64.ln(), 2), // Within range + (1.0f64.ln(), 0), + (5.0f64.ln(), 3), // At boundaries + ]; + + for (q2_val, expected_idx) in test_cases { + assert_eq!( + AlphaSCubicInterpolation::ilogq2below(&data, q2_val), + expected_idx + ); + } + + let data_small = create_test_data_1d(vec![1.0f64.ln(), 2.0f64.ln()], 
vec![0.1, 0.2]); + assert_eq!( + AlphaSCubicInterpolation::ilogq2below(&data_small, 2.0f64.ln()), + 0 + ); + + let data_with_mid = create_test_data_1d( + vec![1.0f64.ln(), 2.0f64.ln(), 3.0f64.ln()], + vec![0.1, 0.2, 0.3], + ); + assert_eq!( + AlphaSCubicInterpolation::ilogq2below(&data_with_mid, 2.0f64.ln()), + 1 + ); + + let data_single = create_test_data_1d(vec![1.0f64.ln()], vec![0.1]); + + let result = std::panic::catch_unwind(|| { + AlphaSCubicInterpolation::ilogq2below(&data_single, 0.5f64.ln()); + }); + assert!(result.is_err()); + + let result = std::panic::catch_unwind(|| { + AlphaSCubicInterpolation::ilogq2below(&data_single, 1.5f64.ln()); + }); + assert!(result.is_err()); + } + + #[test] + fn test_log_chebyshev_interpolation_1d() { + let n = 21; + let x_min: f64 = 0.1; + let x_max: f64 = 10.0; + let x_coords = create_cheby_grid(n, x_min, x_max); + + let f_values: Vec = x_coords.iter().map(|&x| x.ln()).collect(); + let data = create_test_data_1d(x_coords.iter().map(|v| v.ln()).collect(), f_values); + let mut cheby = LogChebyshevInterpolation::<1>::default(); + cheby.init(&data).unwrap(); + + let x_test: f64 = 2.5; + let expected = x_test.ln(); + let result = cheby.interpolate(&data, &[x_test.ln()]).unwrap(); + assert_close(result, expected, EPSILON); + + let x_test_grid = data.grid[0].as_slice().unwrap()[n as usize / 2]; + let expected_grid = x_test_grid; + let result_grid = cheby.interpolate(&data, &[x_test_grid]).unwrap(); + assert_close(result_grid, expected_grid, EPSILON); + } + + #[test] + fn test_log_chebyshev_interpolation_2d() { + let n = 11; + let x_coords = create_cheby_grid(n, 0.1, 10.0); + let y_coords = create_cheby_grid(n, 0.1, 10.0); + + let f_values: Vec = x_coords + .iter() + .flat_map(|&x| y_coords.iter().map(move |&y| x.ln() + y.ln())) + .collect(); + + let data = create_test_data_2d( + x_coords.iter().map(|v| v.ln()).collect(), + y_coords.iter().map(|v| v.ln()).collect(), + f_values, + ); + let mut cheby = 
LogChebyshevInterpolation::<2>::default(); + cheby.init(&data).unwrap(); + + let x_test: f64 = 2.5; + let y_test: f64 = 3.5; + let expected = x_test.ln() + y_test.ln(); + let result = cheby + .interpolate(&data, &[x_test.ln(), y_test.ln()]) + .unwrap(); + + assert_close(result, expected, EPSILON); + } + + #[test] + fn test_log_chebyshev_interpolation_3d() { + let n = 7; + let x_coords = create_cheby_grid(n, 0.1, 10.0); + let y_coords = create_cheby_grid(n, 0.1, 10.0); + let z_coords = create_cheby_grid(n, 0.1, 10.0); + + let f_values: Vec = x_coords + .iter() + .cartesian_product(y_coords.iter()) + .cartesian_product(z_coords.iter()) + .map(|((&x, &y), &z)| x.ln() + y.ln() + z.ln()) + .collect(); + + let data = create_test_data_3d( + x_coords.iter().map(|v| v.ln()).collect(), + y_coords.iter().map(|v| v.ln()).collect(), + z_coords.iter().map(|v| v.ln()).collect(), + f_values, + ); + let mut cheby = LogChebyshevInterpolation::<3>::default(); + cheby.init(&data).unwrap(); + + let x_test: f64 = 2.5; + let y_test: f64 = 3.5; + let z_test: f64 = 4.5; + let expected = x_test.ln() + y_test.ln() + z_test.ln(); + let result = cheby + .interpolate(&data, &[x_test.ln(), y_test.ln(), z_test.ln()]) + .unwrap(); + + assert_close(result, expected, EPSILON); + } + + #[test] + fn test_log_chebyshev_batch_interpolation_1d() { + let n = 21; + let x_min: f64 = 0.1; + let x_max: f64 = 10.0; + let x_coords = create_cheby_grid(n, x_min, x_max); + + let f_values: Vec = x_coords.iter().map(|&x| x.ln()).collect(); + let data = create_test_data_1d(x_coords.iter().map(|v| v.ln()).collect(), f_values); + let mut cheby = LogChebyshevBatchInterpolation::<1>::default(); + cheby.init(&data).unwrap(); + + let test_points = [[2.5f64.ln()], [5.0f64.ln()], [7.5f64.ln()]]; + let expected: Vec = test_points.iter().map(|p| p[0]).collect(); + let results = cheby.interpolate(&data, &test_points).unwrap(); + + for (res, exp) in results.iter().zip(expected.iter()) { + assert_close(*res, *exp, EPSILON); + } + 
} + + #[test] + fn test_log_chebyshev_batch_interpolation_2d() { + let n = 11; + let x_coords = create_cheby_grid(n, 0.1, 10.0); + let y_coords = create_cheby_grid(n, 0.1, 10.0); + + let f_values: Vec = x_coords + .iter() + .flat_map(|&x| y_coords.iter().map(move |&y| x.ln() + y.ln())) + .collect(); + + let data = create_test_data_2d( + x_coords.iter().map(|v| v.ln()).collect(), + y_coords.iter().map(|v| v.ln()).collect(), + f_values, + ); + let mut cheby = LogChebyshevBatchInterpolation::<2>::default(); + cheby.init(&data).unwrap(); + + let test_points = [ + [2.5f64.ln(), 3.5f64.ln()], + [5.0f64.ln(), 6.0f64.ln()], + [7.5f64.ln(), 8.5f64.ln()], + ]; + let expected: Vec = test_points.iter().map(|p| p[0] + p[1]).collect(); + let results = cheby.interpolate(&data, &test_points).unwrap(); + + for (res, exp) in results.iter().zip(expected.iter()) { + assert_close(*res, *exp, EPSILON); + } + } + + #[test] + fn test_log_chebyshev_batch_interpolation_3d() { + let n = 7; + let x_coords = create_cheby_grid(n, 0.1, 10.0); + let y_coords = create_cheby_grid(n, 0.1, 10.0); + let z_coords = create_cheby_grid(n, 0.1, 10.0); + + let f_values: Vec = x_coords + .iter() + .cartesian_product(y_coords.iter()) + .cartesian_product(z_coords.iter()) + .map(|((&x, &y), &z)| x.ln() + y.ln() + z.ln()) + .collect(); + + let data = create_test_data_3d( + x_coords.iter().map(|v| v.ln()).collect(), + y_coords.iter().map(|v| v.ln()).collect(), + z_coords.iter().map(|v| v.ln()).collect(), + f_values, + ); + let mut cheby = LogChebyshevBatchInterpolation::<3>::default(); + cheby.init(&data).unwrap(); + + let test_points = [ + [2.5f64.ln(), 3.5f64.ln(), 4.5f64.ln()], + [5.0f64.ln(), 6.0f64.ln(), 7.0f64.ln()], + [7.5f64.ln(), 8.5f64.ln(), 9.5f64.ln()], + ]; + let expected: Vec = test_points.iter().map(|p| p[0] + p[1] + p[2]).collect(); + let results = cheby.interpolate(&data, &test_points).unwrap(); + + for (res, exp) in results.iter().zip(expected.iter()) { + assert_close(*res, *exp, EPSILON); + } 
+ } +} diff --git a/neopdf_legacy/src/subgrid.rs b/neopdf_legacy/src/subgrid.rs new file mode 100644 index 0000000..deebaef --- /dev/null +++ b/neopdf_legacy/src/subgrid.rs @@ -0,0 +1,417 @@ +//! This module defines the [`SubGrid`] struct and its implementation for PDF grid handling. +//! +//! # Contents +//! +//! - [`ParamRange`], [`RangeParameters`]: Parameter range types for grid axes. +//! - [`SubGrid`]: Represents a region of phase space with a consistent grid and provides +//! methods for subgrid logic. + +use ndarray::{s, Array1, Array6, ArrayView2}; +use serde::{Deserialize, Serialize}; + +use super::interpolator::InterpolationConfig; + +/// Represents the valid range of a parameter, with a minimum and maximum value. +#[derive(Debug, Clone, Copy, Deserialize, Serialize)] +pub struct ParamRange { + /// The minimum value of the parameter. + pub min: f64, + /// The maximum value of the parameter. + pub max: f64, +} + +impl ParamRange { + /// Creates a new `ParamRange`. + /// + /// # Arguments + /// + /// * `min` - The minimum value. + /// * `max` - The maximum value. + pub fn new(min: f64, max: f64) -> Self { + Self { min, max } + } + + /// Checks if a given value is within the parameter range (inclusive). + /// + /// # Arguments + /// + /// * `value` - The value to check. + /// + /// # Returns + /// + /// `true` if the value is within the range, `false` otherwise. + pub fn contains(&self, value: f64) -> bool { + value >= self.min && value <= self.max + } +} + +/// Represents the parameter ranges for `x` and `q2`. +pub struct RangeParameters { + /// The range for the nucleon numbers `A`. + pub nucleons: ParamRange, + /// The range for the AlphaS values `as`. + pub alphas: ParamRange, + /// The range for the transverse momentum `kT`. + pub kt: ParamRange, + /// The range for the momentum fraction `x`. + pub x: ParamRange, + /// The range for the energy scale squared `q2`. + pub q2: ParamRange, +} + +impl RangeParameters { + /// Creates a new `RangeParameters`. 
+ /// + /// # Arguments + /// + /// * `nucleons` - The `ParamRange` for the nuleon numbers `A`. + /// * `alphas` - The `ParamRange` for the strong coupling `as`. + /// * `kt` - The `ParamRange` for the transverse momentum `kT`. + /// * `x` - The `ParamRange` for the momentum fraction `x`. + /// * `q2` - The `ParamRange` for the energy scale `q2`. + pub fn new( + nucleons: ParamRange, + alphas: ParamRange, + kt: ParamRange, + x: ParamRange, + q2: ParamRange, + ) -> Self { + Self { + nucleons, + alphas, + kt, + x, + q2, + } + } +} + +/// Stores the PDF grid data for a single subgrid. +/// +/// A subgrid represents a region of the phase space with a consistent +/// grid of `x` and `Q²` values. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SubGrid { + /// Array of `x` values (momentum fraction). + pub xs: Array1, + /// Array of `Q²` values (energy scale squared). + pub q2s: Array1, + /// Array of `kT` values (transverse momentum). + pub kts: Array1, + /// 6-dimensional grid data: [nucleons, alphas, pids, kT, x, Q²]. + pub grid: Array6, + /// Array of nucleon number values. + pub nucleons: Array1, + /// Array of alpha_s values. + pub alphas: Array1, + /// The valid range for the `nucleons` parameter in this subgrid. + pub nucleons_range: ParamRange, + /// The valid range for the `AlphaS` parameter in this subgrid. + pub alphas_range: ParamRange, + /// The valid range for the `kT` parameter in this subgrid. + pub kt_range: ParamRange, + /// The valid range for the `x` parameter in this subgrid. + pub x_range: ParamRange, + /// The valid range for the `q2` parameter in this subgrid. + pub q2_range: ParamRange, +} + +impl SubGrid { + /// Creates a new `SubGrid` from raw data. + /// + /// # Arguments + /// + /// * `nucleon_numbers` - A vector of nucleon numbers. + /// * `alphas_values` - A vector of alpha_s values. + /// * `kt_subgrid` - A vector of `kT` values. + /// * `xs` - A vector of `x` values. + /// * `q2s` - A vector of `q2` values. 
+ /// * `nflav` - The number of quark flavors. + /// * `grid_data` - A flat vector of grid data points. + /// + /// # Panics + /// + /// Panics if the grid data cannot be reshaped to the expected dimensions. + pub fn new( + nucleon_numbers: Vec, + alphas_values: Vec, + kt_subgrid: Vec, + x_subgrid: Vec, + q2_subgrid: Vec, + nflav: usize, + grid_data: Vec, + ) -> Self { + let xs_range = ParamRange::new(*x_subgrid.first().unwrap(), *x_subgrid.last().unwrap()); + let q2s_range = ParamRange::new(*q2_subgrid.first().unwrap(), *q2_subgrid.last().unwrap()); + let kts_range = ParamRange::new(*kt_subgrid.first().unwrap(), *kt_subgrid.last().unwrap()); + let ncs_range = ParamRange::new( + *nucleon_numbers.first().unwrap(), + *nucleon_numbers.last().unwrap(), + ); + let as_range = ParamRange::new( + *alphas_values.first().unwrap(), + *alphas_values.last().unwrap(), + ); + + let subgrid = Array6::from_shape_vec( + ( + nucleon_numbers.len(), + alphas_values.len(), + kt_subgrid.len(), + x_subgrid.len(), + q2_subgrid.len(), + nflav, + ), + grid_data, + ) + .expect("Failed to create grid") + .permuted_axes([0, 1, 5, 2, 3, 4]) + .as_standard_layout() + .to_owned(); + + Self { + xs: Array1::from_vec(x_subgrid), + q2s: Array1::from_vec(q2_subgrid), + kts: Array1::from_vec(kt_subgrid), + grid: subgrid, + nucleons: Array1::from_vec(nucleon_numbers), + alphas: Array1::from_vec(alphas_values), + nucleons_range: ncs_range, + alphas_range: as_range, + kt_range: kts_range, + x_range: xs_range, + q2_range: q2s_range, + } + } + + /// Checks if a point (..., `x`, `q2`) is within the boundaries of this subgrid. + /// + /// # Arguments + /// + /// * `points` - A slice of coordinates. The order is assumed to be + /// `(A, alpha_s, kT, x, Q2)`, with dimensions only present if they are part of + /// the grid. + /// + /// # Returns + /// + /// `true` if the point is within the subgrid, `false` otherwise. 
+ pub fn contains_point(&self, points: &[f64]) -> bool { + let (expected_len, ranges) = match self.interpolation_config() { + InterpolationConfig::TwoD => (2, vec![]), + InterpolationConfig::ThreeDNucleons => (3, vec![&self.nucleons_range]), + InterpolationConfig::ThreeDAlphas => (3, vec![&self.alphas_range]), + InterpolationConfig::ThreeDKt => (3, vec![&self.kt_range]), + InterpolationConfig::FourDNucleonsAlphas => { + (4, vec![&self.nucleons_range, &self.alphas_range]) + } + InterpolationConfig::FourDNucleonsKt => (4, vec![&self.nucleons_range, &self.kt_range]), + InterpolationConfig::FourDAlphasKt => (4, vec![&self.alphas_range, &self.kt_range]), + InterpolationConfig::FiveD => ( + 5, + vec![&self.nucleons_range, &self.alphas_range, &self.kt_range], + ), + }; + + points.len() == expected_len + && self.x_range.contains(points[expected_len - 2]) + && self.q2_range.contains(points[expected_len - 1]) + && ranges + .iter() + .zip(points) + .all(|(range, &point)| range.contains(point)) + } + + /// Calculates the squared distance from a point to the subgrid's bounding box. + pub fn distance_to_point(&self, points: &[f64]) -> f64 { + self.parameter_ranges() + .iter() + .zip(points) + .map(|(range, &point)| match point { + p if p < range.min => (range.min - p) * (range.min - p), + p if p > range.max => (p - range.max) * (p - range.max), + _ => 0.0, + }) + .sum() + } + + /// Gathers the parameter ranges for the subgrid based on its configuration. 
+ fn parameter_ranges(&self) -> Vec { + let mut ranges = match self.interpolation_config() { + InterpolationConfig::TwoD => vec![], + InterpolationConfig::ThreeDNucleons => vec![self.nucleons_range], + InterpolationConfig::ThreeDAlphas => vec![self.alphas_range], + InterpolationConfig::ThreeDKt => vec![self.kt_range], + InterpolationConfig::FourDNucleonsAlphas => { + vec![self.nucleons_range, self.alphas_range] + } + InterpolationConfig::FourDNucleonsKt => vec![self.nucleons_range, self.kt_range], + InterpolationConfig::FourDAlphasKt => vec![self.alphas_range, self.kt_range], + InterpolationConfig::FiveD => { + vec![self.nucleons_range, self.alphas_range, self.kt_range] + } + }; + ranges.extend([self.x_range, self.q2_range]); + ranges + } + + /// Gets the interpolation configuration for this subgrid. + pub fn interpolation_config(&self) -> InterpolationConfig { + InterpolationConfig::from_dimensions(self.nucleons.len(), self.alphas.len(), self.kts.len()) + } + + /// Gets the parameter ranges for this subgrid. + pub fn ranges(&self) -> RangeParameters { + RangeParameters::new( + self.nucleons_range, + self.alphas_range, + self.kt_range, + self.x_range, + self.q2_range, + ) + } + + /// Gets a 2D slice of the grid for interpolation. + /// + /// This method is only valid for 2D interpolation configurations. + /// + /// # Arguments + /// + /// * `pid_index` - The index of the particle ID (flavor). + /// + /// # Panics + /// + /// Panics if called on a subgrid that is not 2D. 
+ pub fn grid_slice(&self, pid_index: usize) -> ArrayView2<'_, f64> { + match self.interpolation_config() { + InterpolationConfig::TwoD => self.grid.slice(s![0, 0, pid_index, 0, .., ..]), + _ => panic!("grid_slice only valid for 2D interpolation"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn mock_subgrid_2d() -> SubGrid { + SubGrid::new( + vec![1.0], + vec![0.118], + vec![0.0], + vec![0.1, 0.2], + vec![1.0, 2.0], + 1, + vec![1.0, 2.0, 3.0, 4.0], + ) + } + + fn mock_subgrid_3d_nucleons() -> SubGrid { + SubGrid::new( + vec![1.0, 2.0], + vec![0.118], + vec![0.0], + vec![0.1, 0.2], + vec![1.0, 2.0], + 1, + vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + ) + } + + #[test] + fn test_param_range() { + let range = ParamRange::new(1.0, 10.0); + assert!(range.contains(5.0)); + assert!(range.contains(1.0)); + assert!(range.contains(10.0)); + assert!(!range.contains(15.0)); + assert!(!range.contains(0.5)); + } + + #[test] + fn test_range_parameters_new() { + let rp = RangeParameters::new( + ParamRange::new(1.0, 2.0), + ParamRange::new(0.1, 0.2), + ParamRange::new(0.0, 1.0), + ParamRange::new(1e-4, 1.0), + ParamRange::new(1.0, 1000.0), + ); + assert_eq!(rp.nucleons.min, 1.0); + assert_eq!(rp.q2.max, 1000.0); + } + + #[test] + fn test_subgrid_new() { + let sg = mock_subgrid_2d(); + assert_eq!(sg.xs.len(), 2); + assert_eq!(sg.q2s.len(), 2); + assert_eq!(sg.grid.shape(), &[1, 1, 1, 1, 2, 2]); + assert_eq!(sg.x_range.min, 0.1); + assert_eq!(sg.x_range.max, 0.2); + assert_eq!(sg.q2_range.min, 1.0); + assert_eq!(sg.q2_range.max, 2.0); + } + + #[test] + fn test_contains_point_2d() { + let sg = mock_subgrid_2d(); + assert!(sg.contains_point(&[0.15, 1.5])); + assert!(sg.contains_point(&[0.1, 1.0])); + assert!(!sg.contains_point(&[0.5, 1.5])); + assert!(!sg.contains_point(&[0.15, 5.0])); + assert!(!sg.contains_point(&[0.15])); + } + + #[test] + fn test_contains_point_3d_nucleons() { + let sg = mock_subgrid_3d_nucleons(); + assert!(sg.contains_point(&[1.5, 0.15, 1.5])); 
+ assert!(!sg.contains_point(&[5.0, 0.15, 1.5])); + } + + #[test] + fn test_distance_to_point() { + let sg = mock_subgrid_2d(); + assert_eq!(sg.distance_to_point(&[0.15, 1.5]), 0.0); + let d = sg.distance_to_point(&[0.3, 1.5]); + assert!(d > 0.0); + } + + #[test] + fn test_ranges() { + let sg = mock_subgrid_2d(); + let r = sg.ranges(); + assert_eq!(r.x.min, 0.1); + assert_eq!(r.x.max, 0.2); + assert_eq!(r.q2.min, 1.0); + assert_eq!(r.q2.max, 2.0); + } + + #[test] + fn test_grid_slice_2d() { + let sg = mock_subgrid_2d(); + let slice = sg.grid_slice(0); + assert_eq!(slice.shape(), &[2, 2]); + } + + #[test] + #[should_panic] + fn test_grid_slice_panics_for_3d() { + let sg = mock_subgrid_3d_nucleons(); + sg.grid_slice(0); + } + + #[test] + fn test_interpolation_config_from_subgrid() { + let sg_2d = mock_subgrid_2d(); + assert!(matches!( + sg_2d.interpolation_config(), + InterpolationConfig::TwoD + )); + let sg_3d = mock_subgrid_3d_nucleons(); + assert!(matches!( + sg_3d.interpolation_config(), + InterpolationConfig::ThreeDNucleons + )); + } +} diff --git a/neopdf_legacy/src/utils.rs b/neopdf_legacy/src/utils.rs new file mode 100644 index 0000000..cfd010c --- /dev/null +++ b/neopdf_legacy/src/utils.rs @@ -0,0 +1,92 @@ +//! This module provides utility functions for interpolation and grid operations. +//! +//! It includes helpers for finding interval indices in coordinate arrays and for +//! performing 1D cubic interpolation using Hermite basis functions. Finds the index +//! of the interval in a sorted coordinate array that contains the given value. +/// +/// This function performs a binary search to efficiently locate the correct interval. +/// +/// # Arguments +/// +/// * `coords` - A sorted slice of f64 values representing the coordinates. +/// * `value` - The f64 value for which to find the interval. +/// +/// # Returns +/// +/// A `Result` containing the 0-based index of the left bound of the interval if successful. 
+/// Returns an `InterpolateError::ExtrapolateError` if the value is outside the bounds +/// of the `coords` array. +pub fn find_interval_index( + coords: &[f64], + value: f64, +) -> Result { + // Check bounds + if value < coords[0] || value > coords[coords.len() - 1] { + return Err(ninterp::error::InterpolateError::ExtrapolateError( + "Out of Bounds!".to_string(), + )); + } + + // Handle exact match with last coordinate + if value == coords[coords.len() - 1] { + return Ok(coords.len() - 2); + } + + // Binary search for the interval + let mut left = 0; + let mut right = coords.len() - 1; + + while left < right { + let mid = (left + right) / 2; + if coords[mid] <= value { + left = mid + 1; + } else { + right = mid; + } + } + + Ok(left - 1) +} + +/// One-dimensional cubic interpolation using Hermite basis functions. +/// +/// @arg t is the fractional distance of the evaluation x into the dx +/// interval. @arg vl and @arg vh are the function values at the low and +/// high edges of the interval. @arg vdl and @arg vdh are linearly +/// extrapolated value changes from the product of dx and the discrete low- +/// and high-edge derivative estimates. 
/// Evaluates a one-dimensional cubic Hermite segment.
///
/// # Arguments
///
/// * `t` - Fractional distance of the evaluation point into the interval.
/// * `vl` - Function value at the low edge of the interval.
/// * `vdl` - Value change at the low edge: interval width times the
///   derivative estimate there.
/// * `vh` - Function value at the high edge of the interval.
/// * `vdh` - Value change at the high edge: interval width times the
///   derivative estimate there.
///
/// # Returns
///
/// The sum of the four inputs, each weighted by its Hermite basis polynomial
/// in `t`; this gives `vl` at `t = 0` and `vh` at `t = 1`.
pub fn hermite_cubic_interpolate(t: f64, vl: f64, vdl: f64, vh: f64, vdh: f64) -> f64 {
    let t_sq = t * t;
    let t_cu = t_sq * t;

    // One term per Hermite basis polynomial.
    let low_value = (2.0 * t_cu - 3.0 * t_sq + 1.0) * vl;
    let low_slope = (t_cu - 2.0 * t_sq + t) * vdl;
    let high_value = (-2.0 * t_cu + 3.0 * t_sq) * vh;
    let high_slope = (t_cu - t_sq) * vdh;

    low_value + low_slope + high_value + high_slope
}
- Lazy iteration over grid members for memory-efficient processing of large sets. +//! +//! # Key Types +//! +//! - [`GridArrayWithMetadata`]: Container for a grid and its associated metadata. +//! - [`GridArrayCollection`]: Static interface for compressing and decompressing collections of grids. +//! - [`GridArrayReader`]: Provides random access to individual grids in a compressed file. +//! - [`LazyGridArrayIterator`]: Enables lazy, sequential iteration over grid members. +//! +//! See the documentation for each type for more details on available methods and usage patterns. +use std::env; +use std::fs::File; +use std::io::{BufReader, BufWriter, Read, Write}; +use std::path::Path; +use std::sync::Arc; + +use git_version::git_version; +use lz4_flex::frame::{FrameDecoder, FrameEncoder}; + +use super::gridpdf::GridArray; +use super::metadata::MetaData; + +const GIT_VERSION: &str = git_version!( + args = ["--always", "--dirty", "--long", "--tags"], + cargo_prefix = "cargo:", + fallback = "unknown" +); +const CODE_VERSION: &str = env!("CARGO_PKG_VERSION"); + +/// Container for a [`GridArray`] with a shared reference to its associated metadata. +/// +/// Used to bundle grid data and metadata together for convenient access after decompression +/// or random access. +#[derive(Debug)] +pub struct GridArrayWithMetadata { + pub grid: GridArray, + pub metadata: Arc, +} + +/// Static interface for compressing and decompressing of [`GridArray`]s with shared metadata. +/// +/// Provides methods for writing, reading, and extracting metadata from compressed files. +pub struct GridArrayCollection; + +impl GridArrayCollection { + /// Compresses and writes a collection of [`GridArray`]s and shared metadata to a file. + /// + /// # Arguments + /// + /// * `grids` - Slice of grid arrays to compress. + /// * `metadata` - Shared metadata for all grids. + /// * `path` - Output file path. + /// + /// # Returns + /// + /// `Ok(())` on success, or an error if writing fails. 
+ pub fn compress>( + grids: &[&GridArray], + metadata: &MetaData, + path: P, + ) -> Result<(), Box> { + let file = File::create(path)?; + let buf_writer = BufWriter::new(file); + let mut encoder = FrameEncoder::new(buf_writer); + + let mut metadata_mut = metadata.as_latest(); + metadata_mut.git_version = GIT_VERSION.to_string(); + metadata_mut.code_version = CODE_VERSION.to_string(); + + let updated_metadata = MetaData::new_v1(metadata_mut); + let metadata_serialized = bincode::serialize(&updated_metadata)?; + let metadata_size = metadata_serialized.len() as u64; + + let metadata_size_bytes = bincode::serialize(&metadata_size)?; + encoder.write_all(&metadata_size_bytes)?; + encoder.write_all(&metadata_serialized)?; + + // Write number of grids + let count = grids.len() as u64; + let count_bytes = bincode::serialize(&count)?; + encoder.write_all(&count_bytes)?; + + // Serialize all grids first + let mut serialized_grids = Vec::new(); + for grid in grids { + let serialized = bincode::serialize(grid)?; + serialized_grids.push(serialized); + } + + // Calculate offsets relative to start of data section + let mut offsets = Vec::new(); + let mut current_offset = 0u64; + + // Each grid entry has: 8 bytes for size + data + for serialized in &serialized_grids { + offsets.push(current_offset); + current_offset += 8; // size field + current_offset += serialized.len() as u64; + } + + // Write offset table size and offsets + let offset_table_size = (serialized_grids.len() * 8) as u64; + let offset_table_size_bytes = bincode::serialize(&offset_table_size)?; + encoder.write_all(&offset_table_size_bytes)?; + + for offset in &offsets { + let offset_bytes = bincode::serialize(offset)?; + encoder.write_all(&offset_bytes)?; + } + + // Write grid data + for serialized in &serialized_grids { + let size = serialized.len() as u64; + let size_bytes = bincode::serialize(&size)?; + encoder.write_all(&size_bytes)?; + encoder.write_all(serialized)?; + } + + encoder.finish()?; + Ok(()) + } + + 
/// Decompresses and loads all [`GridArray`]s and shared metadata from a file. + /// + /// # Arguments + /// + /// * `path` - Input file path. + /// + /// # Returns + /// + /// A vector of [`GridArrayWithMetadata`] on success, or an error if reading fails. + pub fn decompress>( + path: P, + ) -> Result, Box> { + let file = File::open(path)?; + let buf_reader = BufReader::new(file); + let mut decoder = FrameDecoder::new(buf_reader); + + let mut decompressed = Vec::new(); + decoder.read_to_end(&mut decompressed)?; + + let mut cursor = std::io::Cursor::new(decompressed); + + // Read versioned metadata + let metadata_size: u64 = bincode::deserialize_from(&mut cursor)?; + let mut metadata_bytes = vec![0u8; metadata_size as usize]; + cursor.read_exact(&mut metadata_bytes)?; + + // Deserialize versioned metadata and convert to latest + let versioned_metadata: MetaData = bincode::deserialize(&metadata_bytes)?; + let shared_metadata = Arc::new(versioned_metadata); + let count: u64 = bincode::deserialize_from(&mut cursor)?; + + // Read offset table size (but don't skip it!) + let _offset_table_size: u64 = bincode::deserialize_from(&mut cursor)?; + + // Read the actual offsets + let mut offsets = Vec::with_capacity(count as usize); + for _ in 0..count { + let offset: u64 = bincode::deserialize_from(&mut cursor)?; + offsets.push(offset); + } + + // Now read the grid data + let mut grids = Vec::with_capacity(count as usize); + for _ in 0..count { + let size: u64 = bincode::deserialize_from(&mut cursor)?; + let mut grid_bytes = vec![0u8; size as usize]; + cursor.read_exact(&mut grid_bytes)?; + + let grid: GridArray = bincode::deserialize(&grid_bytes)?; + grids.push(GridArrayWithMetadata { + grid, + metadata: Arc::clone(&shared_metadata), + }); + } + + Ok(grids) + } + + /// Extracts just the metadata from a compressed file without loading the grids. + /// + /// # Arguments + /// + /// * `path` - Input file path. 
+ /// + /// # Returns + /// + /// The [`MetaData`] struct on success, or an error if reading fails. + pub fn extract_metadata>( + path: P, + ) -> Result> { + let file = File::open(path)?; + let buf_reader = BufReader::new(file); + let mut decoder = FrameDecoder::new(buf_reader); + + let mut decompressed = Vec::new(); + decoder.read_to_end(&mut decompressed)?; + + let mut cursor = std::io::Cursor::new(decompressed); + + let metadata_size: u64 = bincode::deserialize_from(&mut cursor)?; + let mut metadata_bytes = vec![0u8; metadata_size as usize]; + cursor.read_exact(&mut metadata_bytes)?; + let metadata: MetaData = bincode::deserialize(&metadata_bytes)?; + + Ok(metadata) + } +} + +/// Provides random access to individual [`GridArray`]s in a compressed file without loading the entire collection. +/// +/// Useful for efficient access to large PDF sets where only a subset of members is needed. +pub struct GridArrayReader { + data: Vec, + metadata: Arc, + offsets: Vec, + count: u64, + data_start: u64, +} + +impl GridArrayReader { + /// Creates a new reader from a file, enabling random access to grid members. + /// + /// # Arguments + /// + /// * `path` - Input file path. + /// + /// # Returns + /// + /// A [`GridArrayReader`] instance on success, or an error if reading fails. + pub fn from_file>(path: P) -> Result> { + let file = File::open(path)?; + let buf_reader = BufReader::new(file); + let mut decoder = FrameDecoder::new(buf_reader); + + let mut data = Vec::new(); + decoder.read_to_end(&mut data)?; + + let mut cursor = std::io::Cursor::new(&data); + + // Read metadata + let metadata_size: u64 = bincode::deserialize_from(&mut cursor)?; + let mut metadata_bytes = vec![0u8; metadata_size as usize]; + cursor.read_exact(&mut metadata_bytes)?; + let metadata: MetaData = bincode::deserialize(&metadata_bytes)?; + let shared_metadata = Arc::new(metadata); + let count: u64 = bincode::deserialize_from(&mut cursor)?; + + // Read offset table size (but don't skip it!) 
+ let _offset_table_size: u64 = bincode::deserialize_from(&mut cursor)?; + + // Read the actual offsets + let mut offsets = Vec::with_capacity(count as usize); + for _ in 0..count { + let offset: u64 = bincode::deserialize_from(&mut cursor)?; + offsets.push(offset); + } + + let data_start = cursor.position(); + + Ok(Self { + data, + metadata: shared_metadata, + offsets, + count, + data_start, + }) + } + + /// Returns the number of grid arrays in the collection. + pub fn len(&self) -> usize { + self.count as usize + } + + /// Returns true if the collection is empty. + pub fn is_empty(&self) -> bool { + self.count == 0 + } + + /// Returns a reference to the shared metadata. + pub fn metadata(&self) -> &Arc { + &self.metadata + } + + /// Loads a specific [`GridArrayWithMetadata`] by index. + /// + /// # Arguments + /// + /// * `index` - The index of the grid to load. + /// + /// # Returns + /// + /// The requested [`GridArrayWithMetadata`] on success, or an error if the index is out + /// of bounds or reading fails. + pub fn load_grid( + &self, + index: usize, + ) -> Result> { + if index >= self.count as usize { + return Err(format!( + "Index {} out of bounds for collection of size {}", + index, self.count + ) + .into()); + } + + let offset = self.data_start + self.offsets[index]; + let mut cursor = std::io::Cursor::new(&self.data); + cursor.set_position(offset); + let size: u64 = bincode::deserialize_from(&mut cursor)?; + + let mut grid_bytes = vec![0u8; size as usize]; + cursor.read_exact(&mut grid_bytes)?; + + let grid: GridArray = bincode::deserialize(&grid_bytes)?; + + Ok(GridArrayWithMetadata { + grid, + metadata: Arc::clone(&self.metadata), + }) + } +} + +/// Iterator for lazily reading [`GridArrayWithMetadata`] members from a compressed file. +/// +/// Useful for memory-efficient sequential processing of large PDF sets. 
+pub struct LazyGridArrayIterator { + cursor: std::io::Cursor>, + remaining: u64, + metadata: Arc, + buffer: Vec, +} + +impl LazyGridArrayIterator { + /// Creates a new lazy iterator from a reader. + /// + /// # Arguments + /// + /// * `reader` - Any type implementing [`Read`]. + /// + /// # Returns + /// + /// A [`LazyGridArrayIterator`] instance on success, or an error if reading fails. + pub fn new(reader: R) -> Result> { + let mut decoder = FrameDecoder::new(reader); + let mut decompressed = Vec::new(); + decoder.read_to_end(&mut decompressed)?; + + let mut cursor = std::io::Cursor::new(decompressed); + + let metadata_size: u64 = bincode::deserialize_from(&mut cursor)?; + let mut metadata_bytes = vec![0u8; metadata_size as usize]; + cursor.read_exact(&mut metadata_bytes)?; + let metadata: MetaData = bincode::deserialize(&metadata_bytes)?; + let shared_metadata = Arc::new(metadata); + + let count: u64 = bincode::deserialize_from(&mut cursor)?; + + // Read and skip the offset table + let offset_table_size: u64 = bincode::deserialize_from(&mut cursor)?; + let mut offset_table_bytes = vec![0u8; offset_table_size as usize]; + cursor.read_exact(&mut offset_table_bytes)?; + + Ok(Self { + cursor, + remaining: count, + metadata: shared_metadata, + buffer: Vec::new(), + }) + } + + /// Creates a new lazy iterator from a file path. + /// + /// # Arguments + /// + /// * `path` - Input file path. + /// + /// # Returns + /// + /// A [`LazyGridArrayIterator`] instance on success, or an error if reading fails. + pub fn from_file>(path: P) -> Result> { + let file = File::open(path)?; + let buf_reader = BufReader::new(file); + Self::new(buf_reader) + } + + /// Returns a reference to the shared metadata. 
+ pub fn metadata(&self) -> &Arc { + &self.metadata + } +} + +impl Iterator for LazyGridArrayIterator { + type Item = Result>; + + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + + let result = (|| -> Result> { + // Read size + let size: u64 = bincode::deserialize_from(&mut self.cursor)?; + + // Read grid data + self.buffer.resize(size as usize, 0); + self.cursor.read_exact(&mut self.buffer)?; + + let grid: GridArray = bincode::deserialize(&self.buffer)?; + + Ok(GridArrayWithMetadata { + grid, + metadata: Arc::clone(&self.metadata), + }) + })(); + + self.remaining -= 1; + Some(result) + } + + fn size_hint(&self) -> (usize, Option) { + let remaining = self.remaining as usize; + (remaining, Some(remaining)) + } +} + +impl ExactSizeIterator for LazyGridArrayIterator {} + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::Array1; + use tempfile::NamedTempFile; + + use crate::metadata::{InterpolatorType, MetaDataV1, SetType}; + + #[test] + fn test_collection_with_metadata() { + let metadata_v1 = MetaDataV1 { + set_desc: "Test PDF".into(), + set_index: 1, + num_members: 2, + x_min: 1e-5, + x_max: 1.0, + q_min: 1.0, + q_max: 1000.0, + flavors: vec![1, 2, 3], + format: "NeoPDF".into(), + alphas_q_values: vec![], + alphas_vals: vec![], + polarised: false, + set_type: SetType::SpaceLike, + interpolator_type: InterpolatorType::LogBicubic, + error_type: "replicas".into(), + hadron_pid: 2212, + git_version: String::new(), + code_version: String::new(), + flavor_scheme: String::new(), + order_qcd: 0, + alphas_order_qcd: 0, + m_w: 0.0, + m_z: 0.0, + m_up: 0.0, + m_down: 0.0, + m_strange: 0.0, + m_charm: 0.0, + m_bottom: 0.0, + m_top: 0.0, + alphas_type: String::new(), + number_flavors: 0, + }; + let metadata = MetaData::new_v1(metadata_v1); + + let test_grid = test_grid(); + let grids = vec![&test_grid, &test_grid]; + let temp_file = NamedTempFile::new().unwrap(); + let path = temp_file.path(); + + GridArrayCollection::compress(&grids, 
&metadata, path).unwrap(); + let extracted = GridArrayCollection::extract_metadata(path).unwrap(); + assert_eq!(metadata.set_desc, extracted.set_desc); + assert_eq!(metadata.set_index, extracted.set_index); + + let decompressed = GridArrayCollection::decompress(path).unwrap(); + assert_eq!(decompressed.len(), 2); + for g in &decompressed { + assert_eq!(g.metadata.set_desc, "Test PDF"); + assert_eq!(g.grid.pids, Array1::from(vec![1, 2, 3])); + } + + let g_iter = LazyGridArrayIterator::from_file(path).unwrap(); + assert_eq!(g_iter.metadata().set_index, 1); + assert_eq!(g_iter.count(), 2); + } + + fn test_grid() -> GridArray { + GridArray { + pids: Array1::from(vec![1, 2, 3]), + subgrids: vec![], + } + } +}