diff --git a/source/source_base/module_container/base/third_party/lapack.h b/source/source_base/module_container/base/third_party/lapack.h index c3a8af2027..0117734993 100644 --- a/source/source_base/module_container/base/third_party/lapack.h +++ b/source/source_base/module_container/base/third_party/lapack.h @@ -1,3 +1,18 @@ +/** + * @file lapack.h + * @brief This is a direct wrapper of some LAPACK routines. + * \b Column-Major version. + * Direct wrapping of standard LAPACK routines. (Column-Major, fortran style) + * + * @warning For Row-major version, please refer to \c source/source_base/module_external/lapack_connector.h. + * + * @note + * Some slight modification are made to fit the C++ style for overloading purpose. + * You can find some function with different parameter list than the original LAPACK routine. + * And some of these parameters are not referred in the function body. They are included just to + * ensure the same parameter list for overloaded functions with a uniform name. + */ + #ifndef BASE_THIRD_PARTY_LAPACK_H_ #define BASE_THIRD_PARTY_LAPACK_H_ @@ -10,6 +25,10 @@ #include #endif +/// This is a wrapper of some LAPACK routines. +/// Direct wrapping of standard LAPACK routines. (column major, fortran style) +/// with some slight modification to fit the C++ style for overloading purpose. + //Naming convention of lapack subroutines : ammxxx, where //"a" specifies the data type: // - d stands for double @@ -46,6 +65,27 @@ void chegvd_(const int* itype, const char* jobz, const char* uplo, const int* n, std::complex* work, int* lwork, float* rwork, int* lrwork, int* iwork, int* liwork, int* info); +void ssygvx_(const int* itype, const char* jobz, const char* range, const char* uplo, + const int* n, float* A, const int* lda, float* B, const int* ldb, + const float* vl, const float* vu, const int* il, const int* iu, + const float* abstol, const int* m, float* w, float* Z, const int* ldz, + float* work, const int* lwork, int* iwork, int* ifail, int* info); +void dsygvx_(const int* itype, const char* jobz, const char* range, const char* uplo, + const int* n, double* A, const int* lda, double* B, const int* ldb, + const double* vl, const double* vu, const int* il, const int* iu, + const double* abstol, const int* m, double* w, double* Z, const int* ldz, + double* work, const int* lwork, int* iwork, int* ifail, int* info); +void chegvx_(const int* itype, const char* jobz, const char* range, const char* uplo, + const int* n, std::complex* A, const int* lda, std::complex* B, const int* ldb, + const float* vl, const float* vu, const int* il, const int* iu, + const float* abstol, const int* m, float* w, std::complex* Z, const int* ldz, + std::complex* work, const int* lwork, float* rwork, int* iwork, int* ifail, int* info); +void zhegvx_(const int* itype, const char* jobz, const char* range, const char* uplo, + const int* n, std::complex* A, const int* lda, std::complex* B, const int* ldb, + const double* vl, const double* vu, const int* il, const int* iu, + const double* abstol, const int* m, double* w, std::complex* Z, const int* ldz, + std::complex* work, const int* lwork, double* rwork, int* iwork, int* ifail, int* info); + void zhegvd_(const int* itype, const char* jobz, const char* uplo, const int* n, std::complex* a, const int* lda, const std::complex* b, const int* ldb, double* w, @@ -190,6 +230,68 @@ void hegvd(const int itype, const char jobz, const char uplo, const int n, iwork, &liwork, &info); } +// Note +// rwork is only needed for complex version +// and we include rwork in the function parameter list +// for simplicity of function overloading +// and unification of function parameter list +static inline +void hegvx(const int itype, const char jobz, const char range, const char uplo, const int n, + float* a, const int lda, float* b, const int ldb, + const float vl, const float vu, const int il, const int iu, const float abstol, + const int m, float* w, float* z, const int ldz, + float* work, const int lwork, float* rwork, int* iwork, int* ifail, int& info) +{ + ssygvx_(&itype, &jobz, &range, &uplo, &n, + a, &lda, b, &ldb, + &vl, &vu, &il, &iu, + &abstol, &m, w, z, &ldz, + work, &lwork, iwork, ifail, &info); +} + +static inline +void hegvx(const int itype, const char jobz, const char range, const char uplo, const int n, + double* a, const int lda, double* b, const int ldb, + const double vl, const double vu, const int il, const int iu, const double abstol, + const int m, double* w, double* z, const int ldz, + double* work, const int lwork, double* rwork, int* iwork, int* ifail, int& info) +{ + dsygvx_(&itype, &jobz, &range, &uplo, &n, + a, &lda, b, &ldb, + &vl, &vu, &il, &iu, + &abstol, &m, w, z, &ldz, + work, &lwork, iwork, ifail, &info); +} + +static inline +void hegvx(const int itype, const char jobz, const char range, const char uplo, const int n, + std::complex* a, const int lda, std::complex* b, const int ldb, + const float vl, const float vu, const int il, const int iu, const float abstol, + const int m, float* w, std::complex* z, const int ldz, + std::complex* work, const int lwork, float* rwork, int* iwork, int* ifail, int& info) +{ + chegvx_(&itype, &jobz, &range, &uplo, &n, + a, &lda, b, &ldb, + &vl, &vu, &il, &iu, + &abstol, &m, w, z, &ldz, + work, &lwork, rwork, iwork, ifail, &info); +} + +static inline +void hegvx(const int itype, const char jobz, const char range, const char uplo, const int n, + std::complex* a, const int lda, std::complex* b, const int ldb, + const double vl, const double vu, const int il, const int iu, const double abstol, + const int m, double* w, std::complex* z, const int ldz, + std::complex* work, const int lwork, double* rwork, int* iwork, int* ifail, int& info) +{ + zhegvx_(&itype, &jobz, &range, &uplo, &n, + a, &lda, b, &ldb, + &vl, &vu, &il, &iu, + &abstol, &m, w, z, &ldz, + work, &lwork, rwork, iwork, ifail, &info); +} + + // wrap function of fortran lapack routine zheevx. static inline void heevx( const int itype, const char jobz, const char range, const char uplo, const int n, diff --git a/source/source_base/module_external/lapack_connector.h b/source/source_base/module_external/lapack_connector.h index 1f691fe3c2..da90261dd0 100644 --- a/source/source_base/module_external/lapack_connector.h +++ b/source/source_base/module_external/lapack_connector.h @@ -1,5 +1,26 @@ -#ifndef LAPACKCONNECTOR_HPP -#define LAPACKCONNECTOR_HPP +/** + * @file lapack_connector.h + * + * @brief This is a wrapper of some LAPACK routines. + * \b Row-Major version. + * + * @warning MAY BE DEPRECATED IN THE FUTURE. + * @warning For Column-major version, please refer to \c source/source_base/module_container/base/third_party/lapack.h. + * + * @note + * !!! Note that + * This wrapper is a C++ style wrapper of LAPACK routines, + * i.e., assuming that the input matrices are in \b row-major order. + * The data layout in C++ is row-major, C style, + * while the original LAPACK is column-major, fortran style. + * (ModuleBase::ComplexMatrix is in row-major order) + * The wrapper will do the data transformation between + * row-major and column-major order automatically. + * + */ + +#ifndef LAPACK_CONNECTOR_HPP +#define LAPACK_CONNECTOR_HPP #include #include @@ -11,8 +32,10 @@ //Naming convention of lapack subroutines : ammxxx, where //"a" specifies the data type: -// - d stands for double -// - z stands for complex double +// - s stands for float +// - d stands for double +// - c stands for complex float +// - z stands for complex double //"mm" specifies the type of matrix, for example: // - he stands for hermitian // - sy stands for symmetric @@ -468,4 +491,4 @@ class LapackConnector cherk_(&uplo_changed, &trans_changed, &n, &k, &alpha, A, &lda, &beta, C, &ldc); } }; -#endif // LAPACKCONNECTOR_HPP +#endif // LAPACK_CONNECTOR_HPP diff --git a/source/source_base/module_external/lapack_wrapper.h b/source/source_base/module_external/lapack_wrapper.h index acccdc0454..a797c31f6d 100644 --- a/source/source_base/module_external/lapack_wrapper.h +++ b/source/source_base/module_external/lapack_wrapper.h @@ -1,5 +1,15 @@ #ifndef LAPACK_HPP #define LAPACK_HPP + +/// This is a wrapper of some LAPACK routines. +/// Direct wrapping of standard LAPACK routines. (column major, fortran style) +/// including: +/// 1. hegvd: compute all the eigenvalues and eigenvectors of a generalized Hermitian-definite eigenproblem +/// 2. heevx: compute the first m eigenvalues and their corresponding eigenvectors of a generalized Hermitian-definite eigenproblem +/// 3. hegvx: compute the first m eigenvalues and their corresponding eigenvectors of a generalized Hermitian-definite eigenproblem +/// 4. hegv: compute all the eigenvalues and eigenvectors of a generalized Hermitian-definite eigenproblem + + #include extern "C" { diff --git a/source/source_hsolver/kernels/hegvd_op.cpp b/source/source_hsolver/kernels/hegvd_op.cpp index 9c28866a4e..ab7c520c3d 100644 --- a/source/source_hsolver/kernels/hegvd_op.cpp +++ b/source/source_hsolver/kernels/hegvd_op.cpp @@ -1,9 +1,12 @@ #include "source_hsolver/kernels/hegvd_op.h" +#include "source_base/module_container/base/third_party/lapack.h" #include #include #include +namespace lapackConnector = container::lapackConnector; // see "source_base/module_container/base/third_party/lapack.h" + namespace hsolver { // hegvd and sygvd; dn for dense? @@ -39,7 +42,7 @@ struct hegvd_op //=========================== // calculate all eigenvalues //=========================== - LapackWrapper::xhegvd(1, + lapackConnector::hegvd(1, 'V', 'U', nstart, @@ -58,7 +61,7 @@ struct hegvd_op if (info != 0) { - std::cout << "Error: xhegvd failed, linear dependent basis functions\n" + std::cout << "Error: hegvd failed, linear dependent basis functions\n" << ", wrong initialization of wavefunction, or wavefunction information loss\n" << ", output overlap matrix scc.txt to check\n" << std::endl; @@ -82,62 +85,62 @@ struct hegvd_op } }; -template -struct hegv_op -{ - using Real = typename GetTypeReal::type; - void operator()(const base_device::DEVICE_CPU* d, - const int nbase, - const int ldh, - const T* hcc, - T* scc, - Real* eigenvalue, - T* vcc) - { - for (int i = 0; i < nbase * ldh; i++) - { - vcc[i] = hcc[i]; - } - - int info = 0; - - int lwork = 2 * nbase - 1; - T* work = new T[lwork]; - Parallel_Reduce::ZEROS(work, lwork); - - int lrwork = 3 * nbase - 2; - Real* rwork = new Real[lrwork]; - Parallel_Reduce::ZEROS(rwork, lrwork); - - //=========================== - // calculate all eigenvalues - //=========================== - LapackWrapper::xhegv(1, 'V', 'U', nbase, vcc, ldh, scc, ldh, eigenvalue, work, lwork, rwork, info); - - if (info != 0) - { - std::cout << "Error: xhegv failed, linear dependent basis functions\n" - << ", wrong initialization of wavefunction, or wavefunction information loss\n" - << ", output overlap matrix scc.txt to check\n" - << std::endl; - // print scc to file scc.txt - std::ofstream ofs("scc.txt"); - for (int i = 0; i < nbase; i++) - { - for (int j = 0; j < nbase; j++) - { - ofs << scc[i * ldh + j] << " "; - } - ofs << std::endl; - } - ofs.close(); - } - assert(0 == info); - - delete[] work; - delete[] rwork; - } -}; +// template +// struct hegv_op +// { +// using Real = typename GetTypeReal::type; +// void operator()(const base_device::DEVICE_CPU* d, +// const int nbase, +// const int ldh, +// const T* hcc, +// T* scc, +// Real* eigenvalue, +// T* vcc) +// { +// for (int i = 0; i < nbase * ldh; i++) +// { +// vcc[i] = hcc[i]; +// } + +// int info = 0; + +// int lwork = 2 * nbase - 1; +// T* work = new T[lwork]; +// Parallel_Reduce::ZEROS(work, lwork); + +// int lrwork = 3 * nbase - 2; +// Real* rwork = new Real[lrwork]; +// Parallel_Reduce::ZEROS(rwork, lrwork); + +// //=========================== +// // calculate all eigenvalues +// //=========================== +// LapackWrapper::xhegv(1, 'V', 'U', nbase, vcc, ldh, scc, ldh, eigenvalue, work, lwork, rwork, info); + +// if (info != 0) +// { +// std::cout << "Error: xhegv failed, linear dependent basis functions\n" +// << ", wrong initialization of wavefunction, or wavefunction information loss\n" +// << ", output overlap matrix scc.txt to check\n" +// << std::endl; +// // print scc to file scc.txt +// std::ofstream ofs("scc.txt"); +// for (int i = 0; i < nbase; i++) +// { +// for (int j = 0; j < nbase; j++) +// { +// ofs << scc[i * ldh + j] << " "; +// } +// ofs << std::endl; +// } +// ofs.close(); +// } +// assert(0 == info); + +// delete[] work; +// delete[] rwork; +// } +// }; // heevx and syevx /** @@ -174,7 +177,7 @@ struct heevx_op // When lwork = -1, the demension of work will be assumed // Assume the denmension of work by output work[0] - LapackWrapper::xheevx( + lapackConnector::heevx( 1, // ITYPE = 1: A*x = (lambda)*B*x 'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors. 'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found. @@ -208,7 +211,7 @@ struct heevx_op // V is the output of the function, the storage space is also (nstart * ldh), and the data size of valid V // obtained by the zhegvx operation is (nstart * nstart) and stored in zux (internal to the function). When // the function is output, the data of zux will be mapped to the corresponding position of V. - LapackWrapper::xheevx( + lapackConnector::heevx( 1, // ITYPE = 1: A*x = (lambda)*B*x 'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors. 'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found. @@ -267,7 +270,7 @@ struct hegvx_op int* iwork = new int[5 * nbase]; int* ifail = new int[nbase]; - LapackWrapper::xhegvx( + lapackConnector::hegvx( 1, // ITYPE = 1: A*x = (lambda)*B*x 'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors. 'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found. @@ -297,7 +300,7 @@ struct hegvx_op delete[] work; work = new T[lwork]; - LapackWrapper::xhegvx(1, + lapackConnector::hegvx(1, 'V', 'I', 'U', @@ -338,12 +341,12 @@ template struct heevx_op, base_device::DEVICE_CPU>; template struct hegvx_op, base_device::DEVICE_CPU>; template struct hegvx_op, base_device::DEVICE_CPU>; -template struct hegv_op, base_device::DEVICE_CPU>; -template struct hegv_op, base_device::DEVICE_CPU>; +// template struct hegv_op, base_device::DEVICE_CPU>; +// template struct hegv_op, base_device::DEVICE_CPU>; #ifdef __LCAO template struct hegvd_op; template struct heevx_op; template struct hegvx_op; -template struct hegv_op; +// template struct hegv_op; #endif } // namespace hsolver \ No newline at end of file diff --git a/source/source_hsolver/kernels/hegvd_op.h b/source/source_hsolver/kernels/hegvd_op.h index 8b440c76ee..5381f97415 100644 --- a/source/source_hsolver/kernels/hegvd_op.h +++ b/source/source_hsolver/kernels/hegvd_op.h @@ -26,7 +26,6 @@ // And will be moved to a global module(module base) later. #include "source_base/macros.h" -#include "source_base/module_external/lapack_wrapper.h" #include "source_base/parallel_reduce.h" #include "source_base/module_device/types.h" @@ -68,22 +67,22 @@ struct hegvd_op void operator()(const Device* d, const int nstart, const int ldh, const T* A, const T* B, Real* W, T* V); }; -template -struct hegv_op -{ - using Real = typename GetTypeReal::type; - /// @brief HEGV computes first m eigenvalues and eigenvectors of a complex generalized - /// Input Parameters - /// @param d : the type of device - /// @param nbase : the number of dim of the matrix - /// @param ldh : the number of dmx of the matrix - /// @param A : the hermitian matrix A in A x=lambda B x (col major) - /// @param B : the overlap matrix B in A x=lambda B x (col major) - /// Output Parameter - /// @param W : calculated eigenvalues - /// @param V : calculated eigenvectors (col major) - void operator()(const Device* d, const int nstart, const int ldh, const T* A, T* B, Real* W, T* V); -}; +// template +// struct hegv_op +// { +// using Real = typename GetTypeReal::type; +// /// @brief HEGV computes first m eigenvalues and eigenvectors of a complex generalized +// /// Input Parameters +// /// @param d : the type of device +// /// @param nbase : the number of dim of the matrix +// /// @param ldh : the number of dmx of the matrix +// /// @param A : the hermitian matrix A in A x=lambda B x (col major) +// /// @param B : the overlap matrix B in A x=lambda B x (col major) +// /// Output Parameter +// /// @param W : calculated eigenvalues +// /// @param V : calculated eigenvectors (col major) +// void operator()(const Device* d, const int nstart, const int ldh, const T* A, T* B, Real* W, T* V); +// }; template struct hegvx_op