From 934e9e99ec93b44777a7a82db5f1ecbb5d8bbb3e Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Fri, 18 Oct 2024 14:15:34 +0800 Subject: [PATCH 01/19] Link mtblas library --- CMakeLists.txt | 11 ++++ .../module_base/kernels/dsp/dsp_connector.h | 63 +++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 source/module_base/kernels/dsp/dsp_connector.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 637aa95d3e..9246dd1821 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,7 @@ option(ENABLE_RAPIDJSON "Enable rapid-json usage." OFF) option(ENABLE_CNPY "Enable cnpy usage." OFF) option(ENABLE_PEXSI "Enable support for PEXSI." OFF) option(ENABLE_CUSOLVERMP "Enable cusolvermp." OFF) +option(ENABLE_DSP "Enable DSP usage." OFF) # enable json support if(ENABLE_RAPIDJSON) @@ -119,6 +120,12 @@ elseif(ENABLE_LCAO AND NOT ENABLE_MPI) set(ABACUS_BIN_NAME abacus_serial) endif() +if (USE_DSP) + set(USE_ELPA OFF) + set(ENABLE_LCAO OFF) + set(ABACUS_BIN_NAME abacus_dsp) +endif() + list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) if(ENABLE_COVERAGE) @@ -240,6 +247,10 @@ if(ENABLE_MPI) list(APPEND math_libs MPI::MPI_CXX) endif() +if (USE_DSP) + target_link_libraries(${ABACUS_BIN_NAME} DIR_MTBLAS_LIBRARY) +endif() + find_package(Threads REQUIRED) target_link_libraries(${ABACUS_BIN_NAME} Threads::Threads) diff --git a/source/module_base/kernels/dsp/dsp_connector.h b/source/module_base/kernels/dsp/dsp_connector.h new file mode 100644 index 0000000000..c5801847e3 --- /dev/null +++ b/source/module_base/kernels/dsp/dsp_connector.h @@ -0,0 +1,63 @@ +#ifdef __DSP + +// Base dsp functions +void createMtblasHandle(int id); +void destroyMtblasHandle(); +void *malloc_ht(size_t bytes); +void free_ht(void* ptr); + + +// mtblas functions + +void sgemm_mt_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const float *alpha, const float *a, const int *lda, + const float *b, const int *ldb, const const float *beta, + const float *c, const int *ldc); + +void dgemm_mt_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const double *alpha,const double *a, const int *lda, + const double *b, const int *ldb, const double *beta, + const double *c, const int *ldc); + +void zgemm_mt_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, + const std::complex *b, const int *ldb, const std::complex *beta, + std::complex *c, const int *ldc); + +void cgemm_mt_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, + const std::complex *b, const int *ldb, const std::complex *beta, + std::complex *c, const int *ldc); + + +void sgemm_mth_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const float *alpha, const float *a, const int *lda, + const float *b, const int *ldb, const const float *beta, + const float *c, const int *ldc); + +void dgemm_mth_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const double *alpha,const double *a, const int *lda, + const double *b, const int *ldb, const double *beta, + const double *c, const int *ldc); + +void zgemm_mth_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, + const std::complex *b, const int *ldb, const std::complex *beta, + std::complex *c, const int *ldc); + +void cgemm_mth_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, + const std::complex *b, const int *ldb, const std::complex *beta, + std::complex *c, const int *ldc); + +//#define zgemm_ zgemm_mt + +#endif \ No newline at end of file From 2d762d9bbb1fd73bf1b60f9d9d3f6dc96e7c42d5 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Fri, 18 Oct 2024 14:17:02 +0800 Subject: [PATCH 02/19] Add mtblas gemm kernel usage --- source/module_base/blas_connector.cpp | 40 ++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 8da2b802fa..ee44f518f5 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -1,5 +1,9 @@ #include "blas_connector.h" +#ifdef __DSP +#include "module_base/kernels/dsp/dsp_connector.h" +#endif + void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { @@ -83,7 +87,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons sgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); -} + } + #ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice){ + sgemm_mt_(&transb, &transa, &n, &m, &k, + &alpha, b, &ldb, a, &lda, + &beta, c, &ldc); + } + #endif } void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, @@ -94,7 +105,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons dgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); -} + } + #ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice){ + sgemm_mt_(&transb, &transa, &n, &m, &k, + &alpha, b, &ldb, a, &lda, + &beta, c, &ldc); + } + #endif } void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, @@ -105,7 +123,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons cgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); -} + } + #ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) { + cgemm_mt_(&transb, &transa, &n, &m, &k, + &alpha, b, &ldb, a, &lda, + &beta, c, &ldc); + } + #endif } void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, @@ -116,7 +141,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons zgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); -} + } + #ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) { + zgemm_mt_(&transb, &transa, &n, &m, &k, + &alpha, b, &ldb, a, &lda, + &beta, c, &ldc); + } + #endif } void BlasConnector::gemv(const char trans, const int m, const int n, From 98fe67de729b76b802a4cbf4174c7ede05ce5200 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Fri, 18 Oct 2024 14:51:52 +0800 Subject: [PATCH 03/19] Finish memory_op on dsp --- source/module_base/module_device/memory_op.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/source/module_base/module_device/memory_op.cpp b/source/module_base/module_device/memory_op.cpp index 1edc05b8fd..5697e8c3f7 100644 --- a/source/module_base/module_device/memory_op.cpp +++ b/source/module_base/module_device/memory_op.cpp @@ -2,6 +2,9 @@ #include "module_base/memory.h" #include "module_base/tool_threading.h" +#ifdef __DSP +#include "module_base/kernels/dsp_connector.h" +#endif #include #include @@ -18,9 +21,17 @@ struct resize_memory_op { if (arr != nullptr) { +#ifdef __DSP + free_ht(arr); +#else free(arr); +#endif } +#ifdef __DSP + arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size); +#else arr = (FPTYPE*)malloc(sizeof(FPTYPE) * size); +#endif std::string record_string; if (record_in != nullptr) { @@ -92,7 +103,11 @@ struct delete_memory_op { void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr) { +#ifdef __DSP + free_ht(arr); +#else free(arr); +#endif } }; From e4cf55c07cf82b5448a86a674c6d793dd21688ea Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Tue, 22 Oct 2024 10:38:07 +0800 Subject: [PATCH 04/19] Update CMakeLists --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9246dd1821..bacc14cc05 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,7 +39,7 @@ option(ENABLE_RAPIDJSON "Enable rapid-json usage." OFF) option(ENABLE_CNPY "Enable cnpy usage." OFF) option(ENABLE_PEXSI "Enable support for PEXSI." OFF) option(ENABLE_CUSOLVERMP "Enable cusolvermp." OFF) -option(ENABLE_DSP "Enable DSP usage." OFF) +option(USE_DSP "Enable DSP usage." OFF) # enable json support if(ENABLE_RAPIDJSON) @@ -248,7 +248,7 @@ if(ENABLE_MPI) endif() if (USE_DSP) - target_link_libraries(${ABACUS_BIN_NAME} DIR_MTBLAS_LIBRARY) + target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTBLAS_LIBRARY}) endif() find_package(Threads REQUIRED) From fab64872d712905b0bf39ed7885ccddf5d9085ee Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Tue, 22 Oct 2024 10:51:24 +0800 Subject: [PATCH 05/19] Add compilation script --- install_dsp.sh | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 install_dsp.sh diff --git a/install_dsp.sh b/install_dsp.sh new file mode 100644 index 0000000000..f662c828fb --- /dev/null +++ b/install_dsp.sh @@ -0,0 +1,10 @@ +CXX=mpicxx \ + cmake -B build \ + -DUSE_DSP=OFF \ + -DENABLE_LCAO=OFF \ + -DFFTW3_DIR=/vol8/appsoftware/fftw/ \ + -DFFTW3_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3.so \ + -DFFTW3_OMP_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3_omp.so \ + -DFFTW3_FLOAT_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3f.so \ + -DLAPACK_DIR=/vol8/appsoftware/openblas/0.3.21/lib \ + -DDIR_MTBLAS_LIBRARY=/vol8/home/dptech_zyz1/develop/packages/libmtblas_abacus \ No newline at end of file From 400aa98fd92f3da095641653245f72128403ea3b Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Tue, 22 Oct 2024 11:00:19 +0800 Subject: [PATCH 06/19] Fix warnings --- source/module_base/blas_connector.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index ee44f518f5..16dc178711 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -68,6 +68,7 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return sdot_(&n, X, &incX, Y, &incY); + return sdot_(&n, X, &incX, Y, &incY); } } @@ -75,6 +76,7 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return ddot_(&n, X, &incX, Y, &incY); + return ddot_(&n, X, &incX, Y, &incY); } } @@ -184,6 +186,7 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return snrm2_( &n, X, &incX ); + return snrm2_( &n, X, &incX ); } } @@ -192,6 +195,7 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return dnrm2_( &n, X, &incX ); + return dnrm2_( &n, X, &incX ); } } @@ -200,6 +204,7 @@ double BlasConnector::nrm2( const int n, const std::complex *X, const in { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return dznrm2_( &n, X, &incX ); + return dznrm2_( &n, X, &incX ); } } From f29c573b80b09f85659eb4eb15dadc51b882503e Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Tue, 22 Oct 2024 12:01:47 +0800 Subject: [PATCH 07/19] Fix install script --- install_dsp.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/install_dsp.sh b/install_dsp.sh index f662c828fb..7ae2f48ffa 100644 --- a/install_dsp.sh +++ b/install_dsp.sh @@ -1,10 +1,10 @@ CXX=mpicxx \ cmake -B build \ - -DUSE_DSP=OFF \ + -DUSE_DSP=ON \ -DENABLE_LCAO=OFF \ -DFFTW3_DIR=/vol8/appsoftware/fftw/ \ -DFFTW3_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3.so \ -DFFTW3_OMP_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3_omp.so \ -DFFTW3_FLOAT_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3f.so \ -DLAPACK_DIR=/vol8/appsoftware/openblas/0.3.21/lib \ - -DDIR_MTBLAS_LIBRARY=/vol8/home/dptech_zyz1/develop/packages/libmtblas_abacus \ No newline at end of file + -DDIR_MTBLAS_LIBRARY=/vol8/home/dptech_zyz1/develop/packages/libmtblas_abacus.so \ No newline at end of file From 7b94610aa3ba421feee9f3356a3cdbd5a643d32b Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Tue, 22 Oct 2024 13:15:00 +0800 Subject: [PATCH 08/19] Initialize DSP hardware --- source/module_base/kernels/dsp/dsp_connector.h | 4 ++-- source/module_esolver/esolver_ks_pw.cpp | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/source/module_base/kernels/dsp/dsp_connector.h b/source/module_base/kernels/dsp/dsp_connector.h index c5801847e3..bb10798810 100644 --- a/source/module_base/kernels/dsp/dsp_connector.h +++ b/source/module_base/kernels/dsp/dsp_connector.h @@ -1,8 +1,8 @@ #ifdef __DSP // Base dsp functions -void createMtblasHandle(int id); -void destroyMtblasHandle(); +void dspInitHandle(int id); +void dspDestoryHandle(); void *malloc_ht(size_t bytes); void free_ht(void* ptr); diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index cd9dd4ce66..8798b808b8 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -67,6 +67,10 @@ ESolver_KS_PW::ESolver_KS_PW() container::kernels::createGpuSolverHandle(); } #endif +#ifdef __DSP + std::cout << " ** Initializing DSP Hardware..." << std::endl; + dspInitHandle(GlobalV::MY_RANK % 4); +#endif } template @@ -92,7 +96,10 @@ ESolver_KS_PW::~ESolver_KS_PW() #endif delete reinterpret_cast*>(this->kspw_psi); } - +#ifdef __DSP + std::cout << " ** Closing DSP Hardware..." << std::endl; + dspDestoryHandle(); +#endif if (PARAM.inp.precision == "single") { delete reinterpret_cast, Device>*>(this->__kspw_psi); From 28a784f85720c98e9330d966010a676c26c05177 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Tue, 22 Oct 2024 13:56:34 +0800 Subject: [PATCH 09/19] Replace gemm in math_kernel --- source/module_base/module_device/types.h | 1 + source/module_hsolver/kernels/math_kernel_op.cpp | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/source/module_base/module_device/types.h b/source/module_base/module_device/types.h index dfa960a1e3..153b6ab8ca 100644 --- a/source/module_base/module_device/types.h +++ b/source/module_base/module_device/types.h @@ -12,6 +12,7 @@ enum AbacusDevice_t UnKnown, CpuDevice, GpuDevice, + DspDevice }; } // namespace base_device diff --git a/source/module_hsolver/kernels/math_kernel_op.cpp b/source/module_hsolver/kernels/math_kernel_op.cpp index 3ad19bd4cc..ee1ec107de 100644 --- a/source/module_hsolver/kernels/math_kernel_op.cpp +++ b/source/module_hsolver/kernels/math_kernel_op.cpp @@ -277,7 +277,11 @@ struct gemm_op T* c, const int& ldc) { +#ifdef __DSP + BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc, base_device::AbacusDevice_t::DspDevice); +#else BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc); +#endif } }; From 1d8b4e79e0bc62ddd12fdec4bc4018906e06551f Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Tue, 22 Oct 2024 14:29:22 +0800 Subject: [PATCH 10/19] Fix CMakeLists Bug --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index bacc14cc05..62dfd41073 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -249,6 +249,7 @@ endif() if (USE_DSP) target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTBLAS_LIBRARY}) + add_compile_definitions(__DSP) endif() find_package(Threads REQUIRED) From acfd3f7563c2c73cb6f133fa4d389ff1dd2fb7c6 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Tue, 22 Oct 2024 14:41:10 +0800 Subject: [PATCH 11/19] Fix bugs #1 --- source/module_base/kernels/dsp/dsp_connector.h | 7 +++++-- source/module_base/module_device/memory_op.cpp | 2 +- source/module_esolver/esolver_ks_pw.cpp | 4 ++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/source/module_base/kernels/dsp/dsp_connector.h b/source/module_base/kernels/dsp/dsp_connector.h index bb10798810..f3849b3dcd 100644 --- a/source/module_base/kernels/dsp/dsp_connector.h +++ b/source/module_base/kernels/dsp/dsp_connector.h @@ -1,3 +1,5 @@ +#ifndef DSP_CONNECTOR_H +#define DSP_CONNECTOR_H #ifdef __DSP // Base dsp functions @@ -12,7 +14,7 @@ void free_ht(void* ptr); void sgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, - const float *b, const int *ldb, const const float *beta, + const float *b, const int *ldb, const float *beta, const float *c, const int *ldc); void dgemm_mt_(const char *transa, const char *transb, @@ -37,7 +39,7 @@ void cgemm_mt_(const char *transa, const char *transb, void sgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, - const float *b, const int *ldb, const const float *beta, + const float *b, const int *ldb, const float *beta, const float *c, const int *ldc); void dgemm_mth_(const char *transa, const char *transb, @@ -60,4 +62,5 @@ void cgemm_mth_(const char *transa, const char *transb, //#define zgemm_ zgemm_mt +#endif #endif \ No newline at end of file diff --git a/source/module_base/module_device/memory_op.cpp b/source/module_base/module_device/memory_op.cpp index 5697e8c3f7..625b535051 100644 --- a/source/module_base/module_device/memory_op.cpp +++ b/source/module_base/module_device/memory_op.cpp @@ -3,7 +3,7 @@ #include "module_base/memory.h" #include "module_base/tool_threading.h" #ifdef __DSP -#include "module_base/kernels/dsp_connector.h" +#include "module_base/kernels/dsp/dsp_connector.h" #endif #include diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 8798b808b8..bf6c0bc450 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -49,6 +49,10 @@ #include #include +#ifdef __DSP +#include "module_base/kernels/dsp/dsp_connector.h" +#endif + namespace ModuleESolver { From c859d5480309fffdbb1cc8578ad0926f84feebf9 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Tue, 22 Oct 2024 14:50:04 +0800 Subject: [PATCH 12/19] Fix bug 2 --- source/module_base/blas_connector.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 16dc178711..075e4df297 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -110,7 +110,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - sgemm_mt_(&transb, &transa, &n, &m, &k, + dgemm_mt_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); } From 687953b40b330fe981e6ee1d0310eff1171be3d8 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Wed, 23 Oct 2024 12:39:49 +0800 Subject: [PATCH 13/19] Fix link to shared library error --- source/module_base/kernels/dsp/dsp_connector.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/source/module_base/kernels/dsp/dsp_connector.h b/source/module_base/kernels/dsp/dsp_connector.h index f3849b3dcd..2d3075fcd1 100644 --- a/source/module_base/kernels/dsp/dsp_connector.h +++ b/source/module_base/kernels/dsp/dsp_connector.h @@ -15,13 +15,13 @@ void sgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, - const float *c, const int *ldc); + float *c, const int *ldc); void dgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha,const double *a, const int *lda, const double *b, const int *ldb, const double *beta, - const double *c, const int *ldc); + double *c, const int *ldc); void zgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, @@ -40,13 +40,13 @@ void sgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, - const float *c, const int *ldc); + float *c, const int *ldc); void dgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha,const double *a, const int *lda, const double *b, const int *ldb, const double *beta, - const double *c, const int *ldc); + double *c, const int *ldc); void zgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, From a6668dde0817f8a562b605d17eee7de71d95626e Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Wed, 23 Oct 2024 18:55:35 +0800 Subject: [PATCH 14/19] Stop use gemm_mt globally --- source/module_hsolver/kernels/math_kernel_op.cpp | 10 ++++++++-- source/module_hsolver/kernels/math_kernel_op.h | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/source/module_hsolver/kernels/math_kernel_op.cpp b/source/module_hsolver/kernels/math_kernel_op.cpp index ee1ec107de..4ba05a2398 100644 --- a/source/module_hsolver/kernels/math_kernel_op.cpp +++ b/source/module_hsolver/kernels/math_kernel_op.cpp @@ -275,10 +275,16 @@ struct gemm_op const int& ldb, const T* beta, T* c, - const int& ldc) + const int& ldc, + bool use_dsp) { #ifdef __DSP - BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc, base_device::AbacusDevice_t::DspDevice); + if (use_dsp){ + BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc, base_device::AbacusDevice_t::DspDevice); + } + else{ + BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc); + } #else BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc); #endif diff --git a/source/module_hsolver/kernels/math_kernel_op.h b/source/module_hsolver/kernels/math_kernel_op.h index a23c9c329f..40f5a5e83d 100644 --- a/source/module_hsolver/kernels/math_kernel_op.h +++ b/source/module_hsolver/kernels/math_kernel_op.h @@ -261,7 +261,7 @@ template struct gemm_op { void operator()(const Device *d, const char &transa, const char &transb, const int &m, const int &n, const int &k, const T *alpha, const T *a, const int &lda, const T *b, const int &ldb, - const T *beta, T *c, const int &ldc); + const T *beta, T *c, const int &ldc, bool usd_dsp = false); }; template struct matrixTranspose_op { From 5dcbaa6b6f7d1f16fa3c0f573a13a48ac64fa7ea Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Wed, 23 Oct 2024 19:01:45 +0800 Subject: [PATCH 15/19] Modify op usage --- source/module_hsolver/diago_dav_subspace.cpp | 7 +++- .../module_hsolver/kernels/math_kernel_op.cpp | 36 +++++++++++++------ .../module_hsolver/kernels/math_kernel_op.h | 32 ++++++++++++++++- 3 files changed, 62 insertions(+), 13 deletions(-) diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp index 1bfd0a73a1..cb576a30ea 100644 --- a/source/module_hsolver/diago_dav_subspace.cpp +++ b/source/module_hsolver/diago_dav_subspace.cpp @@ -262,7 +262,12 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, } } - gemm_op()(this->ctx, +#ifdef __DSP + gemm_op_mt() +#else + gemm_op() +#endif + (this->ctx, 'N', 'N', this->dim, diff --git a/source/module_hsolver/kernels/math_kernel_op.cpp b/source/module_hsolver/kernels/math_kernel_op.cpp index 4ba05a2398..1d9a579bd0 100644 --- a/source/module_hsolver/kernels/math_kernel_op.cpp +++ b/source/module_hsolver/kernels/math_kernel_op.cpp @@ -275,22 +275,36 @@ struct gemm_op const int& ldb, const T* beta, T* c, - const int& ldc, - bool use_dsp) + const int& ldc,) { -#ifdef __DSP - if (use_dsp){ - BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc, base_device::AbacusDevice_t::DspDevice); - } - else{ - BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc); - } -#else BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc); -#endif } }; +#ifdef __DSP +template +struct gemm_op_mt +{ + void operator()(const base_device::DEVICE_CPU* /*ctx*/, + const char& transa, + const char& transb, + const int& m, + const int& n, + const int& k, + const T* alpha, + const T* a, + const int& lda, + const T* b, + const int& ldb, + const T* beta, + T* c, + const int& ldc) + { + BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc, base_device::AbacusDevice_t::DspDevice); + } +}; +#endif + template struct matrixTranspose_op { diff --git a/source/module_hsolver/kernels/math_kernel_op.h b/source/module_hsolver/kernels/math_kernel_op.h index 40f5a5e83d..0daf0e5718 100644 --- a/source/module_hsolver/kernels/math_kernel_op.h +++ b/source/module_hsolver/kernels/math_kernel_op.h @@ -261,9 +261,39 @@ template struct gemm_op { void operator()(const Device *d, const char &transa, const char &transb, const int &m, const int &n, const int &k, const T *alpha, const T *a, const int &lda, const T *b, const int &ldb, - const T *beta, T *c, const int &ldc, bool usd_dsp = false); + const T *beta, T *c, const int &ldc); }; +#ifdef __DSP +// compute C = alpha * op(A) * op(B) + beta * C on DSP Hardware +template struct gemm_op_mt { + /// @brief C = alpha * op(A) * op(B) + beta * C + /// + /// Input Parameters + /// \param d : the type of computing device + /// \param transa : whether to transpose matrix A + /// \param transb : whether to transpose matrix B + /// \param m : first dimension of matrix mulplication + /// \param n : second dimension of matrix mulplication + /// \param k : third dimension of matrix mulplication + /// \param alpha : input constant alpha + /// \param a : input matrix A + /// \param lda : leading dimention of A + /// \param b : input matrix B + /// \param ldb : leading dimention of A + /// \param beta : input constant beta + /// \param c : input matrix C + /// \param ldc : leading dimention of C + /// + /// Output Parameters + /// \param c : output matrix C + void operator()(const Device *d, const char &transa, const char &transb, + const int &m, const int &n, const int &k, const T *alpha, + const T *a, const int &lda, const T *b, const int &ldb, + const T *beta, T *c, const int &ldc); +}; +#endif + template struct matrixTranspose_op { /// @brief transpose the input matrix /// From 0160db3a27c587025a62e4ea9fecf16dc9e494e0 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Wed, 23 Oct 2024 19:14:58 +0800 Subject: [PATCH 16/19] Fix bug --- source/module_hsolver/kernels/math_kernel_op.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_hsolver/kernels/math_kernel_op.cpp b/source/module_hsolver/kernels/math_kernel_op.cpp index 1d9a579bd0..06c5ddbacd 100644 --- a/source/module_hsolver/kernels/math_kernel_op.cpp +++ b/source/module_hsolver/kernels/math_kernel_op.cpp @@ -275,7 +275,7 @@ struct gemm_op const int& ldb, const T* beta, T* c, - const int& ldc,) + const int& ldc) { BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc); } From 94b57bb95a925dc4a1b7cd753ef0d04df7e6c7e6 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Wed, 23 Oct 2024 22:23:52 +0800 Subject: [PATCH 17/19] Fix template usage --- source/module_hsolver/kernels/math_kernel_op.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/source/module_hsolver/kernels/math_kernel_op.cpp b/source/module_hsolver/kernels/math_kernel_op.cpp index 06c5ddbacd..9bf4f04138 100644 --- a/source/module_hsolver/kernels/math_kernel_op.cpp +++ b/source/module_hsolver/kernels/math_kernel_op.cpp @@ -360,6 +360,7 @@ template struct scal_op; template struct axpy_op, base_device::DEVICE_CPU>; template struct gemv_op, base_device::DEVICE_CPU>; template struct gemm_op, base_device::DEVICE_CPU>; +template struct gemm_op_mt, base_device::DEVICE_CPU>; template struct dot_real_op, base_device::DEVICE_CPU>; template struct vector_div_constant_op, base_device::DEVICE_CPU>; template struct vector_mul_vector_op, base_device::DEVICE_CPU>; @@ -374,6 +375,7 @@ template struct scal_op; template struct axpy_op, base_device::DEVICE_CPU>; template struct gemv_op, base_device::DEVICE_CPU>; template struct gemm_op, base_device::DEVICE_CPU>; +template struct gemm_op_mt, base_device::DEVICE_CPU>; template struct dot_real_op, base_device::DEVICE_CPU>; template struct vector_div_constant_op, base_device::DEVICE_CPU>; template struct vector_mul_vector_op, base_device::DEVICE_CPU>; From b64fc01a0df960fe85215a27114c3733d4140908 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Wed, 23 Oct 2024 22:28:31 +0800 Subject: [PATCH 18/19] Fix compilation --- source/module_hsolver/kernels/math_kernel_op.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/source/module_hsolver/kernels/math_kernel_op.cpp b/source/module_hsolver/kernels/math_kernel_op.cpp index 9bf4f04138..02deb41696 100644 --- a/source/module_hsolver/kernels/math_kernel_op.cpp +++ b/source/module_hsolver/kernels/math_kernel_op.cpp @@ -360,7 +360,6 @@ template struct scal_op; template struct axpy_op, base_device::DEVICE_CPU>; template struct gemv_op, base_device::DEVICE_CPU>; template struct gemm_op, base_device::DEVICE_CPU>; -template struct gemm_op_mt, base_device::DEVICE_CPU>; template struct dot_real_op, base_device::DEVICE_CPU>; template struct vector_div_constant_op, base_device::DEVICE_CPU>; template struct vector_mul_vector_op, base_device::DEVICE_CPU>; @@ -375,7 +374,6 @@ template struct scal_op; template struct axpy_op, base_device::DEVICE_CPU>; template struct gemv_op, base_device::DEVICE_CPU>; template struct gemm_op, base_device::DEVICE_CPU>; -template struct gemm_op_mt, base_device::DEVICE_CPU>; template struct dot_real_op, base_device::DEVICE_CPU>; template struct vector_div_constant_op, base_device::DEVICE_CPU>; template struct vector_mul_vector_op, base_device::DEVICE_CPU>; @@ -398,4 +396,8 @@ template struct matrixTranspose_op; template struct matrixSetToAnother; template struct constantvector_addORsub_constantVector_op; #endif +#ifdef __DSP +template struct gemm_op_mt, base_device::DEVICE_CPU>; +template struct gemm_op_mt, base_device::DEVICE_CPU>; +#endif } // namespace hsolver \ No newline at end of file From 6c320fec824bfbbb57e9cec76a53698e717c23f0 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Wed, 23 Oct 2024 22:36:45 +0800 Subject: [PATCH 19/19] Replace all dav_subspace gemm kernels --- source/module_hsolver/diago_dav_subspace.cpp | 35 +++++++++++++++++--- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp index cb576a30ea..7d298be7ac 100644 --- a/source/module_hsolver/diago_dav_subspace.cpp +++ b/source/module_hsolver/diago_dav_subspace.cpp @@ -181,7 +181,12 @@ int Diago_DavSubspace::diag_once(const HPsiFunc& hpsi_func, // updata eigenvectors of Hamiltonian setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax); - gemm_op()(this->ctx, +#ifdef __DSP + gemm_op_mt() +#else + gemm_op() +#endif + (this->ctx, 'N', 'N', this->dim, @@ -307,7 +312,12 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, delmem_real_op()(this->ctx, e_temp_hd); } - gemm_op()(this->ctx, +#ifdef __DSP + gemm_op_mt() +#else + gemm_op() +#endif + (this->ctx, 'N', 'N', this->dim, @@ -391,7 +401,12 @@ void Diago_DavSubspace::cal_elem(const int& dim, { ModuleBase::timer::tick("Diago_DavSubspace", "cal_elem"); - gemm_op()(this->ctx, +#ifdef __DSP + gemm_op_mt() +#else + gemm_op() +#endif + (this->ctx, 'C', 'N', nbase + notconv, @@ -406,7 +421,12 @@ void Diago_DavSubspace::cal_elem(const int& dim, &hcc[nbase * this->nbase_x], this->nbase_x); - gemm_op()(this->ctx, +#ifdef __DSP + gemm_op_mt() +#else + gemm_op() +#endif + (this->ctx, 'C', 'N', nbase + notconv, @@ -608,7 +628,12 @@ void Diago_DavSubspace::refresh(const int& dim, { ModuleBase::timer::tick("Diago_DavSubspace", "refresh"); - gemm_op()(this->ctx, +#ifdef __DSP + gemm_op_mt() +#else + gemm_op() +#endif + (this->ctx, 'N', 'N', this->dim,