Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions source/source_base/module_container/base/third_party/lapack.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
/**
* @file lapack.h
* @brief This is a direct wrapper of some LAPACK routines.
* \b Column-Major version.
* Direct wrapping of standard LAPACK routines. (Column-Major, fortran style)
*
* @warning For Row-major version, please refer to \c source/source_base/module_external/lapack_connector.h.
*
* @note
* Some slight modifications are made to fit C++-style function overloading.
* As a result, some functions have a parameter list that differs from the original LAPACK routine.
* Some of these parameters are not referenced in the function body; they are included only to
* keep the parameter list identical across overloaded functions that share a uniform name.
*/

#ifndef BASE_THIRD_PARTY_LAPACK_H_
#define BASE_THIRD_PARTY_LAPACK_H_

Expand All @@ -10,6 +25,10 @@
#include <base/third_party/hipsolver.h>
#endif

/// This is a wrapper of some LAPACK routines.
/// Direct wrapping of standard LAPACK routines. (column major, fortran style)
/// with some slight modification to fit the C++ style for overloading purpose.

//Naming convention of lapack subroutines : ammxxx, where
//"a" specifies the data type:
// - d stands for double
Expand Down Expand Up @@ -46,6 +65,27 @@ void chegvd_(const int* itype, const char* jobz, const char* uplo, const int* n,
std::complex<float>* work, int* lwork, float* rwork, int* lrwork,
int* iwork, int* liwork, int* info);

// Fortran LAPACK prototypes for the xSYGVX / xHEGVX family (float, double,
// complex<float>, complex<double>). Every argument is passed by address,
// following the Fortran calling convention.
// The complex variants (chegvx_, zhegvx_) take one extra real workspace
// argument, rwork, between lwork and iwork.
// NOTE(review): reference LAPACK documents M as an OUTPUT (the number of
// eigenvalues found), yet it is declared `const int*` here. The hegvx
// wrappers in this header pass the address of a by-value parameter, so that
// count never reaches the caller. Confirm whether this is intended.

// Single-precision real symmetric version.
void ssygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
const int* n, float* A, const int* lda, float* B, const int* ldb,
const float* vl, const float* vu, const int* il, const int* iu,
const float* abstol, const int* m, float* w, float* Z, const int* ldz,
float* work, const int* lwork, int* iwork, int* ifail, int* info);
// Double-precision real symmetric version.
void dsygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
const int* n, double* A, const int* lda, double* B, const int* ldb,
const double* vl, const double* vu, const int* il, const int* iu,
const double* abstol, const int* m, double* w, double* Z, const int* ldz,
double* work, const int* lwork, int* iwork, int* ifail, int* info);
// Single-precision complex Hermitian version (extra rwork argument).
void chegvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
const int* n, std::complex<float>* A, const int* lda, std::complex<float>* B, const int* ldb,
const float* vl, const float* vu, const int* il, const int* iu,
const float* abstol, const int* m, float* w, std::complex<float>* Z, const int* ldz,
std::complex<float>* work, const int* lwork, float* rwork, int* iwork, int* ifail, int* info);
// Double-precision complex Hermitian version (extra rwork argument).
void zhegvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
const int* n, std::complex<double>* A, const int* lda, std::complex<double>* B, const int* ldb,
const double* vl, const double* vu, const int* il, const int* iu,
const double* abstol, const int* m, double* w, std::complex<double>* Z, const int* ldz,
std::complex<double>* work, const int* lwork, double* rwork, int* iwork, int* ifail, int* info);

void zhegvd_(const int* itype, const char* jobz, const char* uplo, const int* n,
std::complex<double>* a, const int* lda,
const std::complex<double>* b, const int* ldb, double* w,
Expand Down Expand Up @@ -190,6 +230,68 @@ void hegvd(const int itype, const char jobz, const char uplo, const int n,
iwork, &liwork, &info);
}

// hegvx, single-precision real overload: forwards to LAPACK ssygvx_.
//
// All scalar arguments are received by value and handed to the Fortran
// routine by address; info is written back through the reference.
//
// Note
// rwork is only needed for the complex versions. It is part of this
// parameter list, and ignored, so that every hegvx overload shares the
// same parameter list.
//
// Note
// m is received by value, so whatever the routine stores in its M argument
// is not reported back to the caller.
static inline
void hegvx(const int itype, const char jobz, const char range, const char uplo, const int n,
           float* a, const int lda, float* b, const int ldb,
           const float vl, const float vu, const int il, const int iu, const float abstol,
           const int m, float* w, float* z, const int ldz,
           float* work, const int lwork, float* rwork, int* iwork, int* ifail, int& info)
{
    // Unused in the real-valued version.
    static_cast<void>(rwork);

    // Reference LAPACK documents M as an output argument. Forward a mutable
    // local copy so the routine never writes into the const parameter.
    int m_found = m;

    ssygvx_(&itype, &jobz, &range, &uplo, &n,
            a, &lda, b, &ldb,
            &vl, &vu, &il, &iu,
            &abstol, &m_found, w, z, &ldz,
            work, &lwork, iwork, ifail, &info);
}

// hegvx, double-precision real overload: forwards to LAPACK dsygvx_.
//
// All scalar arguments are received by value and handed to the Fortran
// routine by address; info is written back through the reference.
// rwork is accepted only to keep the parameter list identical to the
// complex overloads and is ignored here.
// m is received by value, so whatever the routine stores in its M argument
// is not reported back to the caller.
static inline
void hegvx(const int itype, const char jobz, const char range, const char uplo, const int n,
           double* a, const int lda, double* b, const int ldb,
           const double vl, const double vu, const int il, const int iu, const double abstol,
           const int m, double* w, double* z, const int ldz,
           double* work, const int lwork, double* rwork, int* iwork, int* ifail, int& info)
{
    // Unused in the real-valued version.
    static_cast<void>(rwork);

    // Reference LAPACK documents M as an output argument. Forward a mutable
    // local copy so the routine never writes into the const parameter.
    int m_found = m;

    dsygvx_(&itype, &jobz, &range, &uplo, &n,
            a, &lda, b, &ldb,
            &vl, &vu, &il, &iu,
            &abstol, &m_found, w, z, &ldz,
            work, &lwork, iwork, ifail, &info);
}

// hegvx, single-precision complex overload: forwards to LAPACK chegvx_.
//
// All scalar arguments are received by value and handed to the Fortran
// routine by address; info is written back through the reference.
// Unlike the real overloads, rwork is forwarded to the routine.
// m is received by value, so whatever the routine stores in its M argument
// is not reported back to the caller.
static inline
void hegvx(const int itype, const char jobz, const char range, const char uplo, const int n,
           std::complex<float>* a, const int lda, std::complex<float>* b, const int ldb,
           const float vl, const float vu, const int il, const int iu, const float abstol,
           const int m, float* w, std::complex<float>* z, const int ldz,
           std::complex<float>* work, const int lwork, float* rwork, int* iwork, int* ifail, int& info)
{
    // Reference LAPACK documents M as an output argument. Forward a mutable
    // local copy so the routine never writes into the const parameter.
    int m_found = m;

    chegvx_(&itype, &jobz, &range, &uplo, &n,
            a, &lda, b, &ldb,
            &vl, &vu, &il, &iu,
            &abstol, &m_found, w, z, &ldz,
            work, &lwork, rwork, iwork, ifail, &info);
}

// hegvx, double-precision complex overload: forwards to LAPACK zhegvx_.
//
// All scalar arguments are received by value and handed to the Fortran
// routine by address; info is written back through the reference.
// Unlike the real overloads, rwork is forwarded to the routine.
// m is received by value, so whatever the routine stores in its M argument
// is not reported back to the caller.
static inline
void hegvx(const int itype, const char jobz, const char range, const char uplo, const int n,
           std::complex<double>* a, const int lda, std::complex<double>* b, const int ldb,
           const double vl, const double vu, const int il, const int iu, const double abstol,
           const int m, double* w, std::complex<double>* z, const int ldz,
           std::complex<double>* work, const int lwork, double* rwork, int* iwork, int* ifail, int& info)
{
    // Reference LAPACK documents M as an output argument. Forward a mutable
    // local copy so the routine never writes into the const parameter.
    int m_found = m;

    zhegvx_(&itype, &jobz, &range, &uplo, &n,
            a, &lda, b, &ldb,
            &vl, &vu, &il, &iu,
            &abstol, &m_found, w, z, &ldz,
            work, &lwork, rwork, iwork, ifail, &info);
}


// wrap function of fortran lapack routine zheevx.
static inline
void heevx( const int itype, const char jobz, const char range, const char uplo, const int n,
Expand Down
33 changes: 28 additions & 5 deletions source/source_base/module_external/lapack_connector.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,26 @@
#ifndef LAPACKCONNECTOR_HPP
#define LAPACKCONNECTOR_HPP
/**
* @file lapack_connector.h
*
* @brief This is a wrapper of some LAPACK routines.
* \b Row-Major version.
*
* @warning MAY BE DEPRECATED IN THE FUTURE.
* @warning For Column-major version, please refer to \c source/source_base/module_container/base/third_party/lapack.h.
*
* @note
* !!! Note that
* This wrapper is a <b>C++ style</b> wrapper of LAPACK routines,
* i.e., assuming that the input matrices are in \b row-major order.
* The data layout in C++ is row-major, C style,
* while the original LAPACK is column-major, fortran style.
* (ModuleBase::ComplexMatrix is in row-major order)
* The wrapper will do the data transformation between
* row-major and column-major order automatically.
*
*/

#ifndef LAPACK_CONNECTOR_HPP
#define LAPACK_CONNECTOR_HPP

#include <new>
#include <stdexcept>
Expand All @@ -11,8 +32,10 @@

//Naming convention of lapack subroutines : ammxxx, where
//"a" specifies the data type:
// - d stands for double
// - z stands for complex double
// - s stands for float
// - d stands for double
// - c stands for complex float
// - z stands for complex double
//"mm" specifies the type of matrix, for example:
// - he stands for hermitian
// - sy stands for symmetric
Expand Down Expand Up @@ -468,4 +491,4 @@ class LapackConnector
cherk_(&uplo_changed, &trans_changed, &n, &k, &alpha, A, &lda, &beta, C, &ldc);
}
};
#endif // LAPACKCONNECTOR_HPP
#endif // LAPACK_CONNECTOR_HPP
10 changes: 10 additions & 0 deletions source/source_base/module_external/lapack_wrapper.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
#ifndef LAPACK_HPP
#define LAPACK_HPP

/// This is a wrapper of some LAPACK routines.
/// Direct wrapping of standard LAPACK routines. (column major, fortran style)
/// including:
/// 1. hegvd: compute all the eigenvalues and eigenvectors of a generalized Hermitian-definite eigenproblem
/// 2. heevx: compute the first m eigenvalues and their corresponding eigenvectors of a standard Hermitian eigenproblem
/// 3. hegvx: compute the first m eigenvalues and their corresponding eigenvectors of a generalized Hermitian-definite eigenproblem
/// 4. hegv: compute all the eigenvalues and eigenvectors of a generalized Hermitian-definite eigenproblem


#include <iostream>
extern "C"
{
Expand Down
133 changes: 68 additions & 65 deletions source/source_hsolver/kernels/hegvd_op.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#include "source_hsolver/kernels/hegvd_op.h"
#include "source_base/module_container/base/third_party/lapack.h"

#include <algorithm>
#include <fstream>
#include <iostream>

namespace lapackConnector = container::lapackConnector; // see "source_base/module_container/base/third_party/lapack.h"

namespace hsolver
{
// hegvd and sygvd; dn for dense?
Expand Down Expand Up @@ -39,7 +42,7 @@ struct hegvd_op<T, base_device::DEVICE_CPU>
//===========================
// calculate all eigenvalues
//===========================
LapackWrapper::xhegvd(1,
lapackConnector::hegvd(1,
'V',
'U',
nstart,
Expand All @@ -58,7 +61,7 @@ struct hegvd_op<T, base_device::DEVICE_CPU>

if (info != 0)
{
std::cout << "Error: xhegvd failed, linear dependent basis functions\n"
std::cout << "Error: hegvd failed, linear dependent basis functions\n"
<< ", wrong initialization of wavefunction, or wavefunction information loss\n"
<< ", output overlap matrix scc.txt to check\n"
<< std::endl;
Expand All @@ -82,62 +85,62 @@ struct hegvd_op<T, base_device::DEVICE_CPU>
}
};

// CPU specialization of hegv_op.
// Copies hcc into vcc, then hands vcc and scc to LapackWrapper::xhegv
// (itype 1, jobz 'V', uplo 'U') together with freshly allocated, zeroed
// work buffers. If xhegv reports a non-zero info, a diagnostic is printed
// and scc is dumped to "scc.txt" before the assertion on info fires.
template <typename T>
struct hegv_op<T, base_device::DEVICE_CPU>
{
    using Real = typename GetTypeReal<T>::type;
    void operator()(const base_device::DEVICE_CPU* d,
                    const int nbase,
                    const int ldh,
                    const T* hcc,
                    T* scc,
                    Real* eigenvalue,
                    T* vcc)
    {
        // xhegv receives vcc as its matrix argument, so seed it with hcc.
        const int n_elem = nbase * ldh;
        for (int idx = 0; idx < n_elem; ++idx)
        {
            vcc[idx] = hcc[idx];
        }

        int info = 0;

        // Workspace of element type T, zero-initialised.
        int lwork = 2 * nbase - 1;
        T* work = new T[lwork];
        Parallel_Reduce::ZEROS(work, lwork);

        // Workspace of the matching real type, zero-initialised.
        int lrwork = 3 * nbase - 2;
        Real* rwork = new Real[lrwork];
        Parallel_Reduce::ZEROS(rwork, lrwork);

        // Solve for all eigenvalues.
        LapackWrapper::xhegv(1, 'V', 'U', nbase, vcc, ldh, scc, ldh, eigenvalue, work, lwork, rwork, info);

        const bool failed = (info != 0);
        if (failed)
        {
            std::cout << "Error: xhegv failed, linear dependent basis functions\n"
                      << ", wrong initialization of wavefunction, or wavefunction information loss\n"
                      << ", output overlap matrix scc.txt to check\n"
                      << std::endl;

            // Write the nbase x nbase leading part of scc (row stride ldh) to scc.txt.
            std::ofstream ofs("scc.txt");
            for (int row = 0; row < nbase; ++row)
            {
                const int row_offset = row * ldh;
                for (int col = 0; col < nbase; ++col)
                {
                    ofs << scc[row_offset + col] << " ";
                }
                ofs << std::endl;
            }
            ofs.close();
        }
        assert(0 == info);

        delete[] work;
        delete[] rwork;
    }
};
// template <typename T>
// struct hegv_op<T, base_device::DEVICE_CPU>
// {
// using Real = typename GetTypeReal<T>::type;
// void operator()(const base_device::DEVICE_CPU* d,
// const int nbase,
// const int ldh,
// const T* hcc,
// T* scc,
// Real* eigenvalue,
// T* vcc)
// {
// for (int i = 0; i < nbase * ldh; i++)
// {
// vcc[i] = hcc[i];
// }

// int info = 0;

// int lwork = 2 * nbase - 1;
// T* work = new T[lwork];
// Parallel_Reduce::ZEROS(work, lwork);

// int lrwork = 3 * nbase - 2;
// Real* rwork = new Real[lrwork];
// Parallel_Reduce::ZEROS(rwork, lrwork);

// //===========================
// // calculate all eigenvalues
// //===========================
// LapackWrapper::xhegv(1, 'V', 'U', nbase, vcc, ldh, scc, ldh, eigenvalue, work, lwork, rwork, info);

// if (info != 0)
// {
// std::cout << "Error: xhegv failed, linear dependent basis functions\n"
// << ", wrong initialization of wavefunction, or wavefunction information loss\n"
// << ", output overlap matrix scc.txt to check\n"
// << std::endl;
// // print scc to file scc.txt
// std::ofstream ofs("scc.txt");
// for (int i = 0; i < nbase; i++)
// {
// for (int j = 0; j < nbase; j++)
// {
// ofs << scc[i * ldh + j] << " ";
// }
// ofs << std::endl;
// }
// ofs.close();
// }
// assert(0 == info);

// delete[] work;
// delete[] rwork;
// }
// };

// heevx and syevx
/**
Expand Down Expand Up @@ -174,7 +177,7 @@ struct heevx_op<T, base_device::DEVICE_CPU>

// When lwork = -1, LAPACK only performs a workspace query:
// the required dimension of work is returned in work[0].
LapackWrapper::xheevx(
lapackConnector::heevx(
1, // ITYPE = 1: A*x = (lambda)*B*x
'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors.
'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
Expand Down Expand Up @@ -208,7 +211,7 @@ struct heevx_op<T, base_device::DEVICE_CPU>
// V is the output of the function, the storage space is also (nstart * ldh), and the data size of valid V
// obtained by the zhegvx operation is (nstart * nstart) and stored in zux (internal to the function). When
// the function is output, the data of zux will be mapped to the corresponding position of V.
LapackWrapper::xheevx(
lapackConnector::heevx(
1, // ITYPE = 1: A*x = (lambda)*B*x
'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors.
'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
Expand Down Expand Up @@ -267,7 +270,7 @@ struct hegvx_op<T, base_device::DEVICE_CPU>
int* iwork = new int[5 * nbase];
int* ifail = new int[nbase];

LapackWrapper::xhegvx(
lapackConnector::hegvx(
1, // ITYPE = 1: A*x = (lambda)*B*x
'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors.
'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
Expand Down Expand Up @@ -297,7 +300,7 @@ struct hegvx_op<T, base_device::DEVICE_CPU>
delete[] work;
work = new T[lwork];

LapackWrapper::xhegvx(1,
lapackConnector::hegvx(1,
'V',
'I',
'U',
Expand Down Expand Up @@ -338,12 +341,12 @@ template struct heevx_op<std::complex<double>, base_device::DEVICE_CPU>;
template struct hegvx_op<std::complex<float>, base_device::DEVICE_CPU>;
template struct hegvx_op<std::complex<double>, base_device::DEVICE_CPU>;

template struct hegv_op<std::complex<float>, base_device::DEVICE_CPU>;
template struct hegv_op<std::complex<double>, base_device::DEVICE_CPU>;
// template struct hegv_op<std::complex<float>, base_device::DEVICE_CPU>;
// template struct hegv_op<std::complex<double>, base_device::DEVICE_CPU>;
#ifdef __LCAO
template struct hegvd_op<double, base_device::DEVICE_CPU>;
template struct heevx_op<double, base_device::DEVICE_CPU>;
template struct hegvx_op<double, base_device::DEVICE_CPU>;
template struct hegv_op<double, base_device::DEVICE_CPU>;
// template struct hegv_op<double, base_device::DEVICE_CPU>;
#endif
} // namespace hsolver
Loading
Loading