Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 130 additions & 56 deletions include/alp/algorithms/householder_tridiag.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <sstream>

#include <alp.hpp>
#include <graphblas/utils/iscomplex.hpp> // use from grb
#include "../tests/utils/print_alp_containers.hpp"

namespace alp {

Expand All @@ -40,31 +42,59 @@ namespace alp {
*
*/
template<
typename D = double,
typename D,
typename SymmOrHermType,
typename SymmOrHermTridiagonalType,
typename OrthogonalType,
class Ring = Semiring< operators::add< D >, operators::mul< D >, identities::zero, identities::one >,
class Minus = operators::subtract< D >,
class Divide = operators::divide< D > >
class Divide = operators::divide< D >
>
RC householder_tridiag(
Matrix< D, structures::Orthogonal, Dense > &Q,
Matrix< D, structures::SymmetricTridiagonal, Dense > &T, // Need to be add this once alp -> alp is done
const Matrix< D, structures::Symmetric, Dense > &H,
Matrix< D, OrthogonalType, Dense > &Q,
Matrix< D, SymmOrHermTridiagonalType, Dense > &T,
Matrix< D, SymmOrHermType, Dense > &H,
const Ring & ring = Ring(),
const Minus & minus = Minus(),
const Divide & divide = Divide() ) {

RC rc = SUCCESS;

const Scalar< D > zero( ring.template getZero< D >() );
const Scalar< D > one( ring.template getOne< D >() );
const size_t n = nrows( H );

// Q = identity( n )
rc = set( Q , structures::constant::I( n ) );
rc = alp::set( Q, zero );
auto Qdiag = alp::get_view< alp::view::diagonal >( Q );
rc = rc ? rc : alp::set( Qdiag, one );
if( rc != SUCCESS ) {
std::cerr << " set( Q, I ) failed\n";
return rc;
}

// Out of place specification of the computation
Matrix< D, structures::Symmetric, Dense > RR( n );
Matrix< D, SymmOrHermType, Dense > RR( n );

rc = set( RR, H );
if( rc != SUCCESS ) {
std::cerr << " set( RR, H ) failed\n";
return rc;
}
#ifdef DEBUG
print_matrix( " << RR >> ", RR );
#endif

// a temporary for storing the mxm result
Matrix< D, OrthogonalType, Dense > Qtmp( n, n );

for( size_t k = 0; k < n - 2; ++k ) {
#ifdef DEBUG
std::string matname(" << RR(");
matname = matname + std::to_string(k);
matname = matname + std::string( ") >> ");
print_matrix( matname , RR );
#endif

const size_t m = n - k - 1;

Expand All @@ -73,82 +103,126 @@ namespace alp {
// alpha = norm( v ) * v[ 0 ] / norm( v[ 0 ] )
// v = v - alpha * e1
// v = v / norm ( v )
Vector< D, structures::General, Dense > v;
rc = set( v, get_view( RR, utils::range( k + 1, n ), k ) );

Scalar< D > alpha;
auto v_view = get_view( RR, k, utils::range( k + 1, n ) );
Vector< D, structures::General, Dense > v( n - ( k + 1 ) );
rc = set( v, v_view );
if( rc != SUCCESS ) {
std::cerr << " set( v, view ) failed\n";
return rc;
}

Scalar< D > alpha( zero );
rc = norm2( alpha, v, ring );
if( rc != SUCCESS ) {
std::cerr << " norm2( alpha, v, ring ) failed\n";
return rc;
}

rc = eWiseLambda(
[ &v, &alpha ]( const size_t i ) {
if ( i == 0 ) {
Scalar< D > norm_v0( std::abs( v[ i ] ) );
foldl(alpha, v [ i ], ring.getMultiplicativeOperator() );
foldl(alpha, norm_v0, divide );
foldl(v [ i ], alpha, minus );
}
},
v );

Scalar< D > norm_v;
[ &alpha, &ring, &divide, &minus ]( const size_t i, D &val ) {
if ( i == 0 ) {
Scalar< D > norm_v0( std::abs( val ) );
Scalar< D > val_scalar( val );
foldl( alpha, val_scalar, ring.getMultiplicativeOperator() );
foldl( alpha, norm_v0, divide );
foldl( val_scalar, alpha, minus );
val = *val_scalar;
}
},
v
);
if( rc != SUCCESS ) {
std::cerr << " eWiseLambda( lambda, v ) failed\n";
return rc;
}

Scalar< D > norm_v( zero );
rc = norm2( norm_v, v, ring );
if( rc != SUCCESS ) {
std::cerr << " norm2( norm_v, v, ring ) failed\n";
return rc;
}

rc = foldl(v, norm_v, divide );
#ifdef DEBUG
print_vector( " v = ", v );
#endif
// ===== End Computing v =====


// ===== Calculate reflector Qk =====
// Q_k = identity( n )
Matrix< D, structures::Symmetric, Dense > Qk( m );
rc = set( Qk, structures::constant::I( m ) );
Matrix< D, SymmOrHermType, Dense > Qk( n );
rc = alp::set( Qk, zero );
auto Qk_diag = alp::get_view< alp::view::diagonal >( Qk );
rc = rc ? rc : alp::set( Qk_diag, one );

Matrix< D, structures::Symmetric, Dense > vvt( m );
rc = set(vvt, zero );
// vvt = v * v^T
rc = outer( vvt, v, v, ring );
// this part can be rewritten without temp matrix using functors
Matrix< D, SymmOrHermType, Dense > vvt( m );

rc = rc ? rc : set( vvt, outer( v, ring.getMultiplicativeOperator() ) );
// vvt = 2 * vvt
rc = foldr( Scalar< D >( 2 ), vvt, ring.getMultiplicativeOperator() );
rc = rc ? rc : foldr( Scalar< D >( 2 ), vvt, ring.getMultiplicativeOperator() );


#ifdef DEBUG
print_matrix( " vvt ", vvt );
#endif

// Qk = Qk - vvt ( expanded: I - 2 * vvt )
rc = foldl( Qk, vvt, minus );
auto Qk_view = get_view< SymmOrHermType >( Qk, utils::range( k + 1, n ), utils::range( k + 1, n ) );
if ( grb::utils::is_complex< D >::value ) {
rc = rc ? rc : foldl( Qk_view, alp::get_view< alp::view::transpose >( vvt ), minus );
} else {
rc = rc ? rc : foldl( Qk_view, vvt, minus );
}

#ifdef DEBUG
print_matrix( " << Qk >> ", Qk );
#endif
// ===== End of Calculate reflector Qk ====

// ===== Update R =====
// Rk = Qk * Rk * Qk^T

// get a view over RR (temporary of R)
auto Rk = get_view( RR, range( k + 1, n ), range( k + 1, n ) );

// QkRk = Qk * Rk
Matrix< D, structures::Square, Dense > QkRk( m );
rc = set( QkRk, zero );
rc = mxm( QkRk, Qk, Rk, ring );

// Rk = QkRk * QkT
rc = set( Rk, zero );
rc = mxm( Rk, QkRk, Qk, ring );
// Rk = Qk * Rk * Qk

// RRQk = RR * Qk
Matrix< D, structures::Square, Dense > RRQk( n );
rc = rc ? rc : set( RRQk, zero );
rc = rc ? rc : mxm( RRQk, RR, Qk, ring );
if( rc != SUCCESS ) {
std::cerr << " mxm( RRQk, RR, Qk, ring ); failed\n";
return rc;
}
#ifdef DEBUG
print_matrix( " << RR x Qk = >> ", RRQk );
#endif
// RR = Qk * RRQk
rc = rc ? rc : set( RR, zero );
rc = rc ? rc : mxm( RR, Qk, RRQk, ring );

#ifdef DEBUG
print_matrix( " << RR( updated ) >> ", RR );
#endif
// ===== End of Update R =====

// ===== Update Q =====
// Q = Q * conjugate( QkT )
// a temporary for storing the mxm result
Matrix< D, structures::Orthogonal, Dense > Qtmp( m, m );
// a view over smaller portion of Q
auto Qprim = get_view( Q, range( k + 1, n ), range( k + 1, n ) );

// Qtmp = Qprim * QkT
rc = set( Qtmp, zero );
rc = mxm( Qtmp, Qprim, Qk, ring );

// Qprim = Qtmp
rc = set( Qprim, Qtmp );
// Q = Q * Qk

// Qtmp = Q * Qk
rc = rc ? rc : set( Qtmp, zero );
rc = rc ? rc : mxm( Qtmp, Q, Qk, ring );

// Q = Qtmp
rc = rc ? rc : set( Q, Qtmp );
#ifdef DEBUG
print_matrix( " << Q updated >> ", Q );
#endif
// ===== End of Update Q =====
}

// T = RR
rc = set( T, get_view< structures::SymmetricTridiagonal > ( RR ) );

rc = rc ? rc : set( T, get_view< SymmOrHermTridiagonalType > ( RR ) );
return rc;
}
} // namespace algorithms
Expand Down
28 changes: 18 additions & 10 deletions include/alp/reference/blas1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "vector.hpp"
#include "blas0.hpp"
#include "blas2.hpp"
#include <graphblas/utils/iscomplex.hpp> // use from grb

#ifndef NO_CAST_ASSERT
#define NO_CAST_ASSERT( x, y, z ) \
Expand Down Expand Up @@ -2508,7 +2509,11 @@ namespace alp {
std::function< void( typename AddMonoid::D3 &, const size_t, const size_t ) > data_lambda =
[ &x, &y, &anyOp ]( typename AddMonoid::D3 &result, const size_t i, const size_t j ) {
(void) j;
internal::apply( result, x[ i ], y[ i ], anyOp );
internal::apply(
result, x[ i ],
grb::utils::is_complex< InputType2 >::conjugate( y[ i ] ),
anyOp
);
};

std::function< bool() > init_lambda =
Expand Down Expand Up @@ -2634,15 +2639,15 @@ namespace alp {
void >::type * const = NULL
) {
Scalar< IOType, structures::General, backend > res( x );
RC rc = alp::dot< descr >( x,
RC rc = alp::dot< descr >( res,
left, right,
ring.getAdditiveMonoid(),
ring.getMultiplicativeOperator()
);
if( rc != SUCCESS ) {
return rc;
}
/** \internal \todo extract res.value into x */
x = *res;
return SUCCESS;
}

Expand Down Expand Up @@ -2896,16 +2901,18 @@ namespace alp {
class Ring,
Backend backend
>
RC norm2( Scalar< OutputType, OutputStructure, backend > &x,
RC norm2(
Scalar< OutputType, OutputStructure, backend > &x,
const Vector< InputType, InputStructure, Density::Dense, InputView, InputImfR, InputImfC, backend > &y,
const Ring &ring = Ring(),
const typename std::enable_if<
std::is_floating_point< OutputType >::value,
void >::type * const = NULL
std::is_floating_point< OutputType >::value || grb::utils::is_complex< OutputType >::value,
void
>::type * const = NULL
) {
RC ret = alp::dot< descr >( x, y, y, ring );
if( ret == SUCCESS ) {
x = sqrt( x );
*x = sqrt( *x );
}
return ret;
}
Expand All @@ -2923,15 +2930,16 @@ namespace alp {
const Vector< InputType, InputStructure, Density::Dense, InputView, InputImfR, InputImfC, backend > &y,
const Ring &ring = Ring(),
const typename std::enable_if<
std::is_floating_point< OutputType >::value,
void >::type * const = nullptr
std::is_floating_point< OutputType >::value || grb::utils::is_complex< OutputType >::value,
void
>::type * const = nullptr
) {
Scalar< OutputType, structures::General, reference > res( x );
RC rc = norm2( res, y, ring );
if( rc != SUCCESS ) {
return rc;
}
/** \internal \todo extract res.value into x */
x = *res;
return SUCCESS;
}

Expand Down
Loading