From 373a18a977fc8ed65d45390308a1b53779e7be4a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 26 Jul 2020 02:26:45 +0900 Subject: [PATCH 1/3] Impl tridiagonal by LAPACK --- lax/src/tridiagonal.rs | 52 ++++++++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs index 4eb8ff13..3dabd0c6 100644 --- a/lax/src/tridiagonal.rs +++ b/lax/src/tridiagonal.rs @@ -143,7 +143,13 @@ pub trait Tridiagonal_: Scalar + Sized { } macro_rules! impl_tridiagonal { - ($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + (@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork); + }; + (@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => { + impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, ); + }; + (@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => { impl Tridiagonal_ for $scalar { unsafe fn lu_tridiagonal( mut a: Tridiagonal, @@ -153,8 +159,11 @@ macro_rules! impl_tridiagonal { let mut ipiv = vec![0; n as usize]; // We have to calc one-norm before LU factorization let a_opnorm_one = a.opnorm_one(); - $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv) - .as_lapack_result()?; + let mut info = 0; + $gttrf( + n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info, + ); + info.as_lapack_result()?; Ok(LUFactorizedTridiagonal { a, du2, @@ -166,7 +175,12 @@ macro_rules! impl_tridiagonal { unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { let (n, _) = lu.a.l.size(); let ipiv = &lu.ipiv; + let mut work = vec![Self::zero(); 2 * n as usize]; + $( + let mut $iwork = vec![0; n as usize]; + )* let mut rcond = Self::Real::zero(); + let mut info = 0; $gtcon( NormType::One as u8, n, @@ -177,8 +191,11 @@ macro_rules! impl_tridiagonal { ipiv, lu.a_opnorm_one, &mut rcond, - ) - .as_lapack_result()?; + &mut work, + $(&mut $iwork,)* + &mut info, + ); + info.as_lapack_result()?; Ok(rcond) } @@ -192,27 +209,18 @@ macro_rules! impl_tridiagonal { let (_, nrhs) = bl.size(); let ipiv = &lu.ipiv; let ldb = bl.lda(); + let mut info = 0; $gttrs( - lu.a.l.lapacke_layout(), - t as u8, - n, - nrhs, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - b, - ldb, - ) - .as_lapack_result()?; + t as u8, n, nrhs, &lu.a.dl, &lu.a.d, &lu.a.du, &lu.du2, ipiv, b, ldb, &mut info, + ); + info.as_lapack_result()?; Ok(()) } } }; } // impl_tridiagonal! -impl_tridiagonal!(f64, lapacke::dgttrf, lapacke::dgtcon, lapacke::dgttrs); -impl_tridiagonal!(f32, lapacke::sgttrf, lapacke::sgtcon, lapacke::sgttrs); -impl_tridiagonal!(c64, lapacke::zgttrf, lapacke::zgtcon, lapacke::zgttrs); -impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs); +impl_tridiagonal!(@real, f64, lapack::dgttrf, lapack::dgtcon, lapack::dgttrs); +impl_tridiagonal!(@real, f32, lapack::sgttrf, lapack::sgtcon, lapack::sgttrs); +impl_tridiagonal!(@complex, c64, lapack::zgttrf, lapack::zgtcon, lapack::zgttrs); +impl_tridiagonal!(@complex, c32, lapack::cgttrf, lapack::cgtcon, lapack::cgttrs); From 46b0dcd8097dfd1eaec1793943b39543233b174c Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 26 Jul 2020 17:51:23 +0900 Subject: [PATCH 2/3] Transpose if C-contiguous --- lax/src/tridiagonal.rs | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs index 3dabd0c6..e995e6f5 100644 --- a/lax/src/tridiagonal.rs +++ b/lax/src/tridiagonal.rs @@ -2,7 +2,7 @@ //! for tridiagonal matrix use super::*; -use crate::{error::*, layout::MatrixLayout}; +use crate::{error::*, layout::*}; use cauchy::*; use num_traits::Zero; use std::ops::{Index, IndexMut}; @@ -201,19 +201,40 @@ macro_rules! impl_tridiagonal { unsafe fn solve_tridiagonal( lu: &LUFactorizedTridiagonal, - bl: MatrixLayout, + b_layout: MatrixLayout, t: Transpose, b: &mut [Self], ) -> Result<()> { let (n, _) = lu.a.l.size(); - let (_, nrhs) = bl.size(); let ipiv = &lu.ipiv; - let ldb = bl.lda(); + // Transpose if b is C-continuous + let mut b_t = None; + let b_layout = match b_layout { + MatrixLayout::C { .. } => { + b_t = Some(vec![Self::zero(); b.len()]); + transpose(b_layout, b, b_t.as_mut().unwrap()) + } + MatrixLayout::F { .. } => b_layout, + }; + let (ldb, nrhs) = b_layout.size(); let mut info = 0; $gttrs( - t as u8, n, nrhs, &lu.a.dl, &lu.a.d, &lu.a.du, &lu.du2, ipiv, b, ldb, &mut info, + t as u8, + n, + nrhs, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, + ipiv, + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + ldb, + &mut info, ); info.as_lapack_result()?; + if let Some(b_t) = b_t { + transpose(b_layout, &b_t, b); + } Ok(()) } } From 2a6154f10e90326bd28dfdcdfa3684c9aa5a204f Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sun, 26 Jul 2020 18:07:02 +0900 Subject: [PATCH 3/3] Drop unsafe --- lax/src/tridiagonal.rs | 76 +++++++++++++++++++++--------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/lax/src/tridiagonal.rs b/lax/src/tridiagonal.rs index e995e6f5..ea5bb119 100644 --- a/lax/src/tridiagonal.rs +++ b/lax/src/tridiagonal.rs @@ -130,11 +130,11 @@ impl IndexMut<[i32; 2]> for Tridiagonal { pub trait Tridiagonal_: Scalar + Sized { /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using /// partial pivoting with row interchanges. - unsafe fn lu_tridiagonal(a: Tridiagonal) -> Result>; + fn lu_tridiagonal(a: Tridiagonal) -> Result>; - unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; + fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result; - unsafe fn solve_tridiagonal( + fn solve_tridiagonal( lu: &LUFactorizedTridiagonal, bl: MatrixLayout, t: Transpose, @@ -151,18 +151,14 @@ macro_rules! impl_tridiagonal { }; (@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => { impl Tridiagonal_ for $scalar { - unsafe fn lu_tridiagonal( - mut a: Tridiagonal, - ) -> Result> { + fn lu_tridiagonal(mut a: Tridiagonal) -> Result> { let (n, _) = a.l.size(); let mut du2 = vec![Zero::zero(); (n - 2) as usize]; let mut ipiv = vec![0; n as usize]; // We have to calc one-norm before LU factorization let a_opnorm_one = a.opnorm_one(); let mut info = 0; - $gttrf( - n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info, - ); + unsafe { $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info,) }; info.as_lapack_result()?; Ok(LUFactorizedTridiagonal { a, @@ -172,7 +168,7 @@ macro_rules! impl_tridiagonal { }) } - unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { + fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal) -> Result { let (n, _) = lu.a.l.size(); let ipiv = &lu.ipiv; let mut work = vec![Self::zero(); 2 * n as usize]; @@ -181,25 +177,27 @@ macro_rules! impl_tridiagonal { )* let mut rcond = Self::Real::zero(); let mut info = 0; - $gtcon( - NormType::One as u8, - n, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - lu.a_opnorm_one, - &mut rcond, - &mut work, - $(&mut $iwork,)* - &mut info, - ); + unsafe { + $gtcon( + NormType::One as u8, + n, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, + ipiv, + lu.a_opnorm_one, + &mut rcond, + &mut work, + $(&mut $iwork,)* + &mut info, + ); + } info.as_lapack_result()?; Ok(rcond) } - unsafe fn solve_tridiagonal( + fn solve_tridiagonal( lu: &LUFactorizedTridiagonal, b_layout: MatrixLayout, t: Transpose, @@ -218,19 +216,21 @@ macro_rules! impl_tridiagonal { }; let (ldb, nrhs) = b_layout.size(); let mut info = 0; - $gttrs( - t as u8, - n, - nrhs, - &lu.a.dl, - &lu.a.d, - &lu.a.du, - &lu.du2, - ipiv, - b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), - ldb, - &mut info, - ); + unsafe { + $gttrs( + t as u8, + n, + nrhs, + &lu.a.dl, + &lu.a.d, + &lu.a.du, + &lu.du2, + ipiv, + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + ldb, + &mut info, + ); + } info.as_lapack_result()?; if let Some(b_t) = b_t { transpose(b_layout, &b_t, b);