Skip to content
Open
6 changes: 1 addition & 5 deletions doc/specs/stdlib_specialmatrices.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ Experimental

With the exception of `extended precision` and `quadruple precision`, all the types provided by `stdlib_specialmatrices` benefit from specialized kernels for matrix-vector products accessible via the common `spmv` interface.

- For `tridiagonal` matrices, the LAPACK `lagtm` backend is being used.
- For `tridiagonal` matrices, the backend is either LAPACK `lagtm` or the generalized routine `glagtm`, depending on the values and types of `alpha` and `beta`.

#### Syntax

Expand All @@ -110,10 +110,6 @@ With the exception of `extended precision` and `quadruple precision`, all the ty

- `op` (optional) : In-place operator identifier. Shall be a character(1) argument. It can have any of the following values: `N`: no transpose, `T`: transpose, `H`: hermitian or complex transpose.

@warning
Due to limitations of the underlying `lapack` driver, currently `alpha` and `beta` can only take one of the values `[-1, 0, 1]` for `tridiagonal` and `symtridiagonal` matrices. See `lagtm` for more details.
@endwarning

#### Examples

```fortran
Expand Down
1 change: 1 addition & 0 deletions example/specialmatrices/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
ADD_EXAMPLE(specialmatrices_dp_spmv)
ADD_EXAMPLE(specialmatrices_cdp_spmv)
ADD_EXAMPLE(tridiagonal_dp_type)
30 changes: 30 additions & 0 deletions example/specialmatrices/example_specialmatrices_cdp_spmv.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
program example_tridiagonal_matrix_cdp
use stdlib_linalg_constants, only: dp
use stdlib_specialmatrices, only: tridiagonal_cdp_type, tridiagonal, dense, spmv
implicit none

integer, parameter :: n = 5
type(tridiagonal_cdp_type) :: A
complex(dp) :: dl(n-1), dv(n), du(n-1)
complex(dp) :: x(n), y(n), y_dense(n)
integer :: i
complex(dp) :: alpha, beta

dl = [(cmplx(i,i, dp), i=1, n - 1)]
dv = [(cmplx(2*i,2*i, dp), i=1, n)]
du = [(cmplx(3*i,3*i, dp), i=1, n - 1)]

A = tridiagonal(dl, dv, du)

x = (1.0_dp, 0.0_dp)
y = (3.0_dp, -7.0_dp)
y_dense = (0.0_dp, 0.0_dp)
alpha = cmplx(2.0_dp, 3.0_dp)
beta = cmplx(-1.0_dp, 5.0_dp)

y_dense = alpha * matmul(dense(A), x) + beta * y
call spmv(A, x, y, alpha, beta)

print *, 'dense :', y_dense
print *, 'Tridiagonal :', y
end program example_tridiagonal_matrix_cdp
3 changes: 2 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ if (NOT STDLIB_NO_BITSET)
endif()
add_subdirectory(blas)
add_subdirectory(lapack)
add_subdirectory(lapack_extended)

set(fppFiles
stdlib_ascii.fypp
Expand Down Expand Up @@ -124,4 +125,4 @@ configure_stdlib_target(${PROJECT_NAME} f90Files fppFiles cppFiles)

target_link_libraries(${PROJECT_NAME} PUBLIC
$<$<NOT:$<BOOL:${STDLIB_NO_BITSET}>>:bitsets>
blas lapack)
blas lapack lapack_extended)
10 changes: 10 additions & 0 deletions src/lapack_extended/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
set(lapack_extended_fppFiles
../stdlib_kinds.fypp
stdlib_lapack_extended_base.fypp
stdlib_lapack_extended.fypp
)
set(lapack_extended_cppFiles
../stdlib_linalg_constants.fypp
)

configure_stdlib_target(lapack_extended "" lapack_extended_fppFiles lapack_extended_cppFiles)
85 changes: 85 additions & 0 deletions src/lapack_extended/stdlib_lapack_extended.fypp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#:include "common.fypp"
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
#:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES

submodule(stdlib_lapack_extended_base) stdlib_lapack_extended
implicit none
contains
#:for ik,it,ii in LINALG_INT_KINDS_TYPES
#:for k1,t1,s1 in KINDS_TYPES
pure module subroutine stdlib${ii}$_glagtm_${s1}$(trans, n, nrhs, alpha, dl, d, du, x, ldx, beta, b, ldb)
character, intent(in) :: trans
integer(${ik}$), intent(in) :: ldb, ldx, n, nrhs
${t1}$, intent(in) :: alpha, beta
${t1}$, intent(inout) :: b(ldb,*)
${t1}$, intent(in) :: d(*), dl(*), du(*), x(ldx,*)

! Internal variables.
integer(${ik}$) :: i, j
${t1}$ :: temp
if(n == 0) then
return
endif
if(beta == 0.0_${k1}$) then
b(1:n, 1:nrhs) = 0.0_${k1}$
else
b(1:n, 1:nrhs) = beta * b(1:n, 1:nrhs)
end if

if(trans == 'N') then
do j = 1, nrhs
if(n == 1_${ik}$) then
temp = d(1_${ik}$) * x(1_${ik}$, j)
b(1_${ik}$, j) = b(1_${ik}$, j) + alpha * temp
else
temp = d(1_${ik}$) * x(1_${ik}$, j) + du(1_${ik}$) * x(2_${ik}$, j)
b(1_${ik}$, j) = b(1_${ik}$, j) + alpha * temp
do i = 2, n - 1
temp = dl(i - 1) * x(i - 1, j) + d(i) * x(i, j) + du(i) * x(i + 1, j)
b(i, j) = b(i, j) + alpha * temp
end do
temp = dl(n - 1) * x(n - 1, j) + d(n) * x(n, j)
b(n, j) = b(n, j) + alpha * temp
end if
end do
#:if t1.startswith('complex')
else if(trans == 'C') then
do j = 1, nrhs
if(n == 1_${ik}$) then
temp = conjg(d(1_${ik}$)) * x(1_${ik}$, j)
b(1_${ik}$, j) = b(1_${ik}$, j) + alpha * temp
else
temp = conjg(d(1_${ik}$)) * x(1_${ik}$, j) + conjg(dl(1_${ik}$)) * x(2_${ik}$, j)
b(1_${ik}$, j) = b(1_${ik}$, j) + alpha * temp
do i = 2, n - 1
temp = conjg(du(i - 1)) * x(i - 1, j) + conjg(d(i)) * x(i, j) + conjg(dl(i)) * x(i + 1, j)
b(i, j) = b(i, j) + alpha * temp
end do
temp = conjg(du(n - 1)) * x(n - 1, j) + conjg(d(n)) * x(n, j)
b(n, j) = b(n, j) + alpha * temp
end if
end do
#:endif
else
do j = 1, nrhs
if(n == 1_${ik}$) then
temp = d(1_${ik}$) * x(1_${ik}$, j)
b(1_${ik}$, j) = b(1_${ik}$, j) + alpha * temp
else
temp = d(1_${ik}$) * x(1_${ik}$, j) + dl(1_${ik}$) * x(2_${ik}$, j)
b(1_${ik}$, j) = b(1_${ik}$, j) + alpha * temp
do i = 2, n - 1
temp = du(i - 1) * x(i - 1, j) + d(i) * x(i, j) + dl(i) * x(i + 1, j)
b(i, j) = b(i, j) + alpha * temp
end do
temp = du(n - 1) * x(n - 1, j) + d(n) * x(n, j)
b(n, j) = b(n, j) + alpha * temp
end if
end do
end if
end subroutine stdlib${ii}$_glagtm_${s1}$
#:endfor
#:endfor

end submodule
22 changes: 22 additions & 0 deletions src/lapack_extended/stdlib_lapack_extended_base.fypp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#:include "common.fypp"
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
#:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES
module stdlib_lapack_extended_base
use stdlib_linalg_constants
implicit none

interface glagtm
#:for ik,it,ii in LINALG_INT_KINDS_TYPES
#:for k1,t1,s1 in KINDS_TYPES
pure module subroutine stdlib${ii}$_glagtm_${s1}$(trans, n, nrhs, alpha, dl, d, du, x, ldx, beta, b, ldb)
character, intent(in) :: trans
integer(${ik}$), intent(in) :: ldb, ldx, n, nrhs
${t1}$, intent(in) :: alpha, beta
${t1}$, intent(inout) :: b(ldb,*)
${t1}$, intent(in) :: d(*), dl(*), du(*), x(ldx,*)
end subroutine stdlib${ii}$_glagtm_${s1}$
#:endfor
#:endfor
end interface
end module
9 changes: 5 additions & 4 deletions src/stdlib_specialmatrices.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ module stdlib_specialmatrices
use stdlib_constants
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR
use stdlib_lapack_extended_base
implicit none
private
public :: tridiagonal
Expand Down Expand Up @@ -99,7 +100,7 @@ module stdlib_specialmatrices
!! Matrix dimension.
type(tridiagonal_${s1}$_type) :: A
!! Corresponding Tridiagonal matrix.
end function
end function

module function initialize_tridiagonal_impure_${s1}$(dl, dv, du, err) result(A)
!! Construct a `tridiagonal` matrix from the rank-1 arrays
Expand All @@ -122,7 +123,7 @@ module stdlib_specialmatrices
!! Error handling.
type(tridiagonal_${s1}$_type) :: A
!! Corresponding Tridiagonal matrix.
end function
end function
#:endfor
end interface

Expand All @@ -145,8 +146,8 @@ module stdlib_specialmatrices
type(tridiagonal_${s1}$_type), intent(in) :: A
${t1}$, intent(in), contiguous, target :: x${ranksuffix(rank)}$
${t1}$, intent(inout), contiguous, target :: y${ranksuffix(rank)}$
real(${k1}$), intent(in), optional :: alpha
real(${k1}$), intent(in), optional :: beta
${t1}$, intent(in), optional :: alpha
${t1}$, intent(in), optional :: beta
character(1), intent(in), optional :: op
end subroutine
#:endfor
Expand Down
34 changes: 29 additions & 5 deletions src/stdlib_specialmatrices_tridiagonal.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,18 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
type(tridiagonal_${s1}$_type), intent(in) :: A
${t1}$, intent(in), contiguous, target :: x${ranksuffix(rank)}$
${t1}$, intent(inout), contiguous, target :: y${ranksuffix(rank)}$
real(${k1}$), intent(in), optional :: alpha
real(${k1}$), intent(in), optional :: beta
${t1}$, intent(in), optional :: alpha
${t1}$, intent(in), optional :: beta
character(1), intent(in), optional :: op

! Internal variables.
real(${k1}$) :: alpha_, beta_
${t1}$ :: alpha_, beta_
integer(ilp) :: n, nrhs, ldx, ldy
character(1) :: op_
#:if t1.startswith('real')
logical :: is_alpha_special, is_beta_special
#:endif

#:if rank == 1
${t1}$, pointer :: xmat(:, :), ymat(:, :)
#:endif
Expand All @@ -171,6 +175,10 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
alpha_ = 1.0_${k1}$ ; if (present(alpha)) alpha_ = alpha
beta_ = 0.0_${k1}$ ; if (present(beta)) beta_ = beta
op_ = "N" ; if (present(op)) op_ = op
#:if t1.startswith('real')
is_alpha_special = (alpha_ == 1.0_${k1}$ .or. alpha_ == 0.0_${k1}$ .or. alpha_ == -1.0_${k1}$)
is_beta_special = (beta_ == 1.0_${k1}$ .or. beta_ == 0.0_${k1}$ .or. beta_ == -1.0_${k1}$)
#:endif

! Prepare Lapack arguments.
n = A%n ; ldx = n ; ldy = n ;
Expand All @@ -179,9 +187,25 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
#:if rank == 1
! Pointer trick.
xmat(1:n, 1:nrhs) => x ; ymat(1:n, 1:nrhs) => y
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
#:if t1.startswith('complex')
call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
#:else
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
if(is_alpha_special .and. is_beta_special) then
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
else
call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, xmat, ldx, beta_, ymat, ldy)
end if
#:endif
#:else
#:if t1.startswith('complex')
call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
#:else
if(is_alpha_special .and. is_beta_special) then
call lagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
else
call glagtm(op_, n, nrhs, alpha_, A%dl, A%dv, A%du, x, ldx, beta_, y, ldy)
end if
#:endif
#:endif
end subroutine
#:endfor
Expand Down
33 changes: 33 additions & 0 deletions test/linalg/test_linalg_specialmatrices.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,39 @@ contains
if (allocated(error)) return
end do
end do

! Test y = A @ x for random values of alpha and beta
y1 = 0.0_wp
call random_number(alpha)
call random_number(beta)
call random_number(y2)
y1 = alpha * matmul(Amat, x) + beta * y2
call spmv(A, x, y2, alpha=alpha, beta=beta)
call check(error, all_close(y1, y2), .true.)
if (allocated(error)) return

! Test y = A.T @ x for random values of alpha and beta
y1 = 0.0_wp
call random_number(alpha)
call random_number(beta)
call random_number(y2)
y1 = alpha * matmul(transpose(Amat), x) + beta * y2
call spmv(A, x, y2, alpha=alpha, beta=beta, op="T")
call check(error, all_close(y1, y2), .true.)
if (allocated(error)) return

#:if t1.startswith('complex')
! Test y = A.H @ x for random values of alpha and beta
y1 = 0.0_wp
call random_number(alpha)
call random_number(beta)
call random_number(y2)
y1 = alpha * matmul(transpose(conjg((Amat))), x) + beta * y2
call spmv(A, x, y2, alpha=alpha, beta=beta, op="H")
call check(error, all_close(y1, y2), .true.)
if (allocated(error)) return
#:endif

end block
#:endfor
end subroutine
Expand Down
Loading