diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5745173cc1..0ccee1f673 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -449,10 +449,28 @@ if(USE_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=${OpenMP_CXX_FLAGS}" CACHE STRING "CUDA flags" FORCE)
endif()
if (ENABLE_CUSOLVERMP)
- # Keep cuSolverMp discovery/linking logic in a dedicated module.
+ # Keep cuSOLVERMp discovery/linking logic in a dedicated module.
include(cmake/SetupCuSolverMp.cmake)
abacus_setup_cusolvermp(${ABACUS_BIN_NAME})
endif()
+ if (ENABLE_CUBLASMP)
+ # Enforcement 1: cuBLASMp requires cuSOLVERMp to be enabled
+ if (NOT ENABLE_CUSOLVERMP)
+ message(FATAL_ERROR
+ "ENABLE_CUBLASMP is set to ON, but ENABLE_CUSOLVERMP is OFF."
+ "In ABACUS, cuBLASMp support requires cuSOLVERMp to be enabled simultaneously."
+ "Please set -DENABLE_CUSOLVERMP=ON.")
+ endif()
+ # Enforcement 2: cuBLASMp 0.8.0+ is incompatible with CAL backend
+    # Note: _use_cal is set inside abacus_setup_cusolvermp and propagated here via PARENT_SCOPE
+ if (_use_cal)
+ message(FATAL_ERROR
+ "cuBLASMp 0.8.0+ requires NCCL Symmetric Memory, but cuSOLVERMp is using CAL backend."
+ "Please upgrade cuSOLVERMp to >= 0.7.0 to use NCCL for both.")
+ endif()
+ include(cmake/SetupCuBlasMp.cmake)
+ abacus_setup_cublasmp(${ABACUS_BIN_NAME})
+ endif()
endif()
endif()
diff --git a/cmake/SetupCuBlasMp.cmake b/cmake/SetupCuBlasMp.cmake
new file mode 100644
index 0000000000..7937a02936
--- /dev/null
+++ b/cmake/SetupCuBlasMp.cmake
@@ -0,0 +1,78 @@
+# =============================================================================
+# Configure cuBLASMp dependencies and linking for ABACUS
+# =============================================================================
+
+include_guard(GLOBAL)
+
+function(abacus_setup_cublasmp target_name)
+ add_compile_definitions(__CUBLASMP)
+
+ # 1. Search for cuBLASMp library and header files
+ # libcublasmp.so
+ find_library(CUBLASMP_LIBRARY NAMES cublasmp
+ HINTS ${CUBLASMP_PATH} ${NVHPC_ROOT_DIR}
+ PATH_SUFFIXES lib lib64 math_libs/lib math_libs/lib64)
+
+ # cublasmp.h
+ find_path(CUBLASMP_INCLUDE_DIR NAMES cublasmp.h
+ HINTS ${CUBLASMP_PATH} ${NVHPC_ROOT_DIR}
+ PATH_SUFFIXES include math_libs/include)
+
+ if(NOT CUBLASMP_LIBRARY OR NOT CUBLASMP_INCLUDE_DIR)
+ message(FATAL_ERROR
+ "cuBLASMp not found. Please ensure CUBLASMP_PATH is set correctly."
+ )
+ endif()
+
+ message(STATUS "Found cuBLASMp: ${CUBLASMP_LIBRARY}")
+
+ # 2. Version validation by parsing header macros
+ set(CUBLASMP_VERSION_STR "")
+ set(CUBLASMP_VERSION_HEADER "${CUBLASMP_INCLUDE_DIR}/cublasmp.h")
+
+ if(EXISTS "${CUBLASMP_VERSION_HEADER}")
+ # Extract version lines using regular expressions from cublasmp.h
+ file(STRINGS "${CUBLASMP_VERSION_HEADER}" CUBLASMP_MAJOR_LINE
+ REGEX "^#define[ \t]+CUBLASMP_VER_MAJOR[ \t]+[0-9]+")
+ file(STRINGS "${CUBLASMP_VERSION_HEADER}" CUBLASMP_MINOR_LINE
+ REGEX "^#define[ \t]+CUBLASMP_VER_MINOR[ \t]+[0-9]+")
+ file(STRINGS "${CUBLASMP_VERSION_HEADER}" CUBLASMP_PATCH_LINE
+ REGEX "^#define[ \t]+CUBLASMP_VER_PATCH[ \t]+[0-9]+")
+
+ # Extract numeric values from the matched strings
+ string(REGEX MATCH "([0-9]+)" CUBLASMP_VER_MAJOR "${CUBLASMP_MAJOR_LINE}")
+ string(REGEX MATCH "([0-9]+)" CUBLASMP_VER_MINOR "${CUBLASMP_MINOR_LINE}")
+ string(REGEX MATCH "([0-9]+)" CUBLASMP_VER_PATCH "${CUBLASMP_PATCH_LINE}")
+
+    if(NOT "${CUBLASMP_VER_MAJOR}" STREQUAL ""
+       AND NOT "${CUBLASMP_VER_MINOR}" STREQUAL ""
+       AND NOT "${CUBLASMP_VER_PATCH}" STREQUAL "")
+ set(CUBLASMP_VERSION_STR
+ "${CUBLASMP_VER_MAJOR}.${CUBLASMP_VER_MINOR}.${CUBLASMP_VER_PATCH}")
+ endif()
+ endif()
+
+ message(STATUS "Detected cuBLASMp version: ${CUBLASMP_VERSION_STR}")
+
+ # 3. Version constraint: ABACUS requires cuBLASMp >= 0.8.0
+ if(CUBLASMP_VERSION_STR AND CUBLASMP_VERSION_STR VERSION_LESS "0.8.0")
+ message(FATAL_ERROR
+ "cuBLASMp version ${CUBLASMP_VERSION_STR} is too old. "
+ "ABACUS requires cuBLASMp >= 0.8.0 for NCCL Symmetric Memory support."
+ )
+ elseif(NOT CUBLASMP_VERSION_STR)
+ message(WARNING "Could not detect cuBLASMp version. Proceeding cautiously.")
+ endif()
+
+ # 4. Create cublasMp::cublasMp imported target
+ if(NOT TARGET cublasMp::cublasMp)
+    add_library(cublasMp::cublasMp INTERFACE IMPORTED)
+ set_target_properties(cublasMp::cublasMp PROPERTIES
+ INTERFACE_LINK_LIBRARIES "${CUBLASMP_LIBRARY};NCCL::NCCL"
+ INTERFACE_INCLUDE_DIRECTORIES "${CUBLASMP_INCLUDE_DIR}")
+ endif()
+
+ # 5. Link the library to the target
+ target_link_libraries(${target_name} cublasMp::cublasMp)
+
+endfunction()
diff --git a/cmake/SetupCuSolverMp.cmake b/cmake/SetupCuSolverMp.cmake
index a2bcd00bdf..004665686b 100644
--- a/cmake/SetupCuSolverMp.cmake
+++ b/cmake/SetupCuSolverMp.cmake
@@ -1,5 +1,5 @@
# =============================================================================
-# Configure cuSolverMp dependencies and linking for ABACUS
+# Configure cuSOLVERMp dependencies and linking for ABACUS
# =============================================================================
include_guard(GLOBAL)
@@ -7,7 +7,7 @@ include_guard(GLOBAL)
function(abacus_setup_cusolvermp target_name)
add_compile_definitions(__CUSOLVERMP)
- # Find cuSolverMp first, then decide communicator backend.
+ # Find cuSOLVERMp first, then decide communicator backend.
find_library(CUSOLVERMP_LIBRARY NAMES cusolverMp
HINTS ${CAL_CUSOLVERMP_PATH} ${NVHPC_ROOT_DIR}
PATH_SUFFIXES lib lib64 math_libs/lib math_libs/lib64)
@@ -18,11 +18,11 @@ function(abacus_setup_cusolvermp target_name)
if(NOT CUSOLVERMP_LIBRARY OR NOT CUSOLVERMP_INCLUDE_DIR)
message(FATAL_ERROR
- "cusolverMp not found. Set CUSOLVERMP_PATH or NVHPC_ROOT_DIR."
+ "cuSOLVERMp not found. Set CUSOLVERMP_PATH or NVHPC_ROOT_DIR."
)
endif()
- message(STATUS "Found cusolverMp: ${CUSOLVERMP_LIBRARY}")
+ message(STATUS "Found cuSOLVERMp: ${CUSOLVERMP_LIBRARY}")
set(CUSOLVERMP_VERSION_STR "")
set(CUSOLVERMP_VERSION_HEADER "${CUSOLVERMP_INCLUDE_DIR}/cusolverMp.h")
@@ -47,27 +47,30 @@ function(abacus_setup_cusolvermp target_name)
# Check minimum version requirement (>= 0.4.0)
if(CUSOLVERMP_VERSION_STR AND CUSOLVERMP_VERSION_STR VERSION_LESS "0.4.0")
message(FATAL_ERROR
- "cuSolverMp version ${CUSOLVERMP_VERSION_STR} is too old. "
- "ABACUS requires cuSolverMp >= 0.4.0 (NVIDIA HPC SDK >= 23.5). "
+ "cuSOLVERMp version ${CUSOLVERMP_VERSION_STR} is too old. "
+ "ABACUS requires cuSOLVERMp >= 0.4.0 (NVIDIA HPC SDK >= 23.5). "
"Please upgrade your NVIDIA HPC SDK installation."
)
endif()
- # Auto-select communicator backend by cuSolverMp version.
- # cuSolverMp < 0.7.0 -> CAL, otherwise -> NCCL.
+ # Auto-select communicator backend by cuSOLVERMp version.
+ # cuSOLVERMp < 0.7.0 -> CAL, otherwise -> NCCL.
set(_use_cal OFF)
if(CUSOLVERMP_VERSION_STR AND CUSOLVERMP_VERSION_STR VERSION_LESS "0.7.0")
set(_use_cal ON)
message(STATUS
- "Detected cuSolverMp ${CUSOLVERMP_VERSION_STR} (< 0.7.0). Using CAL backend.")
+ "Detected cuSOLVERMp ${CUSOLVERMP_VERSION_STR} (< 0.7.0). Using CAL backend.")
elseif(CUSOLVERMP_VERSION_STR)
message(STATUS
- "Detected cuSolverMp ${CUSOLVERMP_VERSION_STR} (>= 0.7.0). Using NCCL backend.")
+ "Detected cuSOLVERMp ${CUSOLVERMP_VERSION_STR} (>= 0.7.0). Using NCCL backend.")
elseif(NOT CUSOLVERMP_VERSION_STR)
message(WARNING
- "Unable to detect cuSolverMp version from header. Using NCCL backend by default.")
+ "Unable to detect cuSOLVERMp version from header. Using NCCL backend by default.")
endif()
+  # Propagate _use_cal to the caller's scope (consumed by the cuBLASMp checks in CMakeLists.txt)
+ set(_use_cal ${_use_cal} PARENT_SCOPE)
+
# Backend selection:
# - _use_cal=ON -> cal communicator backend
# - _use_cal=OFF -> NCCL communicator backend
diff --git a/source/source_esolver/esolver_ks_lcao_tddft.cpp b/source/source_esolver/esolver_ks_lcao_tddft.cpp
index 2e463acdcd..de109583a1 100644
--- a/source/source_esolver/esolver_ks_lcao_tddft.cpp
+++ b/source/source_esolver/esolver_ks_lcao_tddft.cpp
@@ -30,7 +30,10 @@ ESolver_KS_LCAO_TDDFT
::ESolver_KS_LCAO_TDDFT()
if (ct_device_type == ct::DeviceType::GpuDevice)
{
use_tensor = true;
- use_lapack = true;
+ if (PARAM.inp.ks_solver != "cusolvermp")
+ {
+ use_lapack = true;
+ }
}
}
@@ -235,21 +238,22 @@ void ESolver_KS_LCAO_TDDFT
::hamilt2rho_single(UnitCell& ucell,
{
if (istep >= TD_info::estep_shift + 1)
{
- module_rt::Evolve_elec::solve_psi(istep,
- PARAM.inp.nbands,
- PARAM.globalv.nlocal,
- this->kv.get_nks(),
- static_cast>*>(this->p_hamilt),
- this->pv,
- this->psi,
- this->psi_laststep,
- this->Hk_laststep,
- this->Sk_laststep,
- this->pelec->ekb,
- GlobalV::ofs_running,
- PARAM.inp.propagator,
- use_tensor,
- use_lapack);
+ module_rt::Evolve_elec::solve_psi(
+ istep,
+ PARAM.inp.nbands,
+ PARAM.globalv.nlocal,
+ this->kv.get_nks(),
+ static_cast>*>(this->p_hamilt),
+ this->pv,
+ this->psi,
+ this->psi_laststep,
+ this->Hk_laststep,
+ this->Sk_laststep,
+ this->pelec->ekb,
+ GlobalV::ofs_running,
+ PARAM.inp.propagator,
+ use_tensor,
+ use_lapack);
}
this->weight_dm_rho(ucell);
}
@@ -346,11 +350,18 @@ void ESolver_KS_LCAO_TDDFT::iter_finish(UnitCell& ucell,
{
if (use_tensor && use_lapack)
{
- elecstate::cal_edm_tddft_tensor_lapack(this->pv, this->dmat, this->kv, static_cast>*>(this->p_hamilt));
+ elecstate::cal_edm_tddft_tensor_lapack(
+ this->pv,
+ this->dmat,
+ this->kv,
+ static_cast>*>(this->p_hamilt));
}
else
{
- elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, static_cast>*>(this->p_hamilt));
+ elecstate::cal_edm_tddft(this->pv,
+ this->dmat,
+ this->kv,
+ static_cast>*>(this->p_hamilt));
}
}
}
diff --git a/source/source_esolver/esolver_ks_lcao_tddft.h b/source/source_esolver/esolver_ks_lcao_tddft.h
index 53e6ac77f5..f534b303f4 100644
--- a/source/source_esolver/esolver_ks_lcao_tddft.h
+++ b/source/source_esolver/esolver_ks_lcao_tddft.h
@@ -3,10 +3,11 @@
#include "esolver_ks.h"
#include "esolver_ks_lcao.h"
#include "source_base/module_container/ATen/core/tensor.h" // ct::Tensor
-#include "source_lcao/module_rt/gather_mat.h" // MPI gathering and distributing functions
+#include "source_lcao/module_rt/boundary_fix.h"
+#include "source_lcao/module_rt/gather_mat.h" // MPI gathering and distributing functions
+#include "source_lcao/module_rt/kernels/cublasmp_context.h"
#include "source_lcao/module_rt/td_info.h"
#include "source_lcao/module_rt/velocity_op.h"
-#include "source_lcao/module_rt/boundary_fix.h"
namespace ModuleESolver
{
@@ -51,6 +52,7 @@ class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO, TR>
//! Control heterogeneous computing of the TDDFT solver
bool use_tensor = false;
bool use_lapack = false;
+ CublasMpResources cublas_res;
// Control the device type for Hk_laststep and Sk_laststep
// Set to CPU temporarily, should wait for further GPU development
diff --git a/source/source_lcao/module_rt/CMakeLists.txt b/source/source_lcao/module_rt/CMakeLists.txt
index 35a65a5aa7..a1056c3347 100644
--- a/source/source_lcao/module_rt/CMakeLists.txt
+++ b/source/source_lcao/module_rt/CMakeLists.txt
@@ -22,6 +22,8 @@ if(ENABLE_LCAO)
list(APPEND objects
kernels/cuda/snap_psibeta_kernel.cu
kernels/cuda/snap_psibeta_gpu.cu
+ kernels/cuda/norm_psi_kernel.cu
+ kernels/cuda/band_energy_kernel.cu
)
endif()
diff --git a/source/source_lcao/module_rt/band_energy.cpp b/source/source_lcao/module_rt/band_energy.cpp
index 63b2fad71b..601840b285 100644
--- a/source/source_lcao/module_rt/band_energy.cpp
+++ b/source/source_lcao/module_rt/band_energy.cpp
@@ -4,6 +4,10 @@
#include "source_base/module_container/ATen/kernels/blas.h"
#include "source_base/module_external/scalapack_connector.h"
+#ifdef __CUBLASMP
+#include "kernels/cuda/band_energy_kernel.cuh"
+#endif
+
#include
#include
@@ -165,140 +169,226 @@ void compute_ekb_tensor(const Parallel_Orbitals* pv,
const ct::Tensor& Htmp,
const ct::Tensor& psi_k,
ct::Tensor& ekb,
- std::ofstream& ofs_running)
+ std::ofstream& ofs_running,
+ CublasMpResources& cublas_res)
{
- assert(pv->nloc_wfc > 0 && pv->nloc > 0);
-
- // Create Tensor objects for temporary data
- ct::Tensor tmp1(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({pv->nloc_wfc}));
- tmp1.zero();
-
- ct::Tensor eij(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({pv->nloc}));
- eij.zero();
-
- // Perform matrix multiplication: tmp1 = Htmp * psi_k
- ScalapackConnector::gemm('N',
- 'N',
- nlocal,
- nband,
- nlocal,
- 1.0,
- Htmp.data>(),
- 1,
- 1,
- pv->desc,
- psi_k.data>(),
- 1,
- 1,
- pv->desc_wfc,
- 0.0,
- tmp1.data>(),
- 1,
- 1,
- pv->desc_wfc);
-
- // Perform matrix multiplication: eij = psi_k^dagger * tmp1
- ScalapackConnector::gemm('C',
- 'N',
- nband,
- nband,
- nlocal,
- 1.0,
- psi_k.data>(),
- 1,
- 1,
- pv->desc_wfc,
- tmp1.data>(),
- 1,
- 1,
- pv->desc_wfc,
- 0.0,
- eij.data>(),
- 1,
- 1,
- pv->desc_Eij);
-
- if (PARAM.inp.td_print_eij >= 0.0)
+#ifdef __CUBLASMP
+ // 1. Resource validation
+ if (!cublas_res.is_initialized || cublas_res.cublasmp_grid == nullptr)
{
- ofs_running
- << "------------------------------------------------------------------------------------------------"
- << std::endl;
- ofs_running << " Eij:" << std::endl;
- for (int i = 0; i < pv->nrow_bands; i++)
- {
- const int in = i * pv->ncol;
- for (int j = 0; j < pv->ncol_bands; j++)
- {
- double aa = eij.data>()[in + j].real();
- double bb = eij.data>()[in + j].imag();
- if (std::abs(aa) < PARAM.inp.td_print_eij)
- {
- aa = 0.0;
- }
- if (std::abs(bb) < PARAM.inp.td_print_eij)
- {
- bb = 0.0;
- }
- if (std::abs(aa) > 0.0 || std::abs(bb) > 0.0)
- {
- std::streamsize original_precision = ofs_running.precision();
- ofs_running << std::fixed << std::setprecision(8);
- ofs_running << "i = " << std::setw(2) << i << ", j = " << std::setw(2) << j
- << ", Eij = " << std::setw(12) << aa << " + " << std::setw(12) << bb << " i"
- << std::endl;
- ofs_running.unsetf(std::ios_base::fixed);
- ofs_running.precision(original_precision);
- }
- }
- }
- ofs_running << std::endl;
- ofs_running
- << "------------------------------------------------------------------------------------------------"
- << std::endl;
+ return;
}
- int info = 0;
- int naroc[2] = {0, 0};
+ assert(pv->nloc_wfc > 0 && pv->nloc > 0);
+ assert(Htmp.device_type() == ct::DeviceType::GpuDevice);
+ assert(psi_k.device_type() == ct::DeviceType::GpuDevice);
+ assert(ekb.device_type() == ct::DeviceType::GpuDevice);
- // Create a Tensor for eii
- assert(nband > 0);
- ct::Tensor eii(ct::DataType::DT_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({nband}));
- eii.zero();
+ // 2. Data Pointers
+ void* d_H = static_cast(const_cast*>(Htmp.data>()));
+ void* d_Psi = static_cast(const_cast*>(psi_k.data>()));
- for (int iprow = 0; iprow < pv->dim0; ++iprow)
- {
- for (int ipcol = 0; ipcol < pv->dim1; ++ipcol)
- {
- if (iprow == pv->coord[0] && ipcol == pv->coord[1])
- {
- naroc[0] = pv->nrow;
- naroc[1] = pv->ncol;
- for (int j = 0; j < naroc[1]; ++j)
- {
- int igcol = globalIndex(j, pv->nb, pv->dim1, ipcol);
- if (igcol >= nband)
- {
- continue;
- }
- for (int i = 0; i < naroc[0]; ++i)
- {
- int igrow = globalIndex(i, pv->nb, pv->dim0, iprow);
- if (igrow >= nband)
- {
- continue;
- }
- if (igcol == igrow)
- {
- eii.data()[igcol] = eij.data>()[j * naroc[0] + i].real();
- }
- }
- }
- }
- } // loop ipcol
- } // loop iprow
+ int64_t psi_elems = psi_k.NumElements();
+ ct::Tensor Tmp1_gpu(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::GpuDevice, ct::TensorShape({psi_elems}));
+ void* d_Tmp1 = static_cast(Tmp1_gpu.data>());
+
+ int64_t eij_elems = pv->nloc;
+ ct::Tensor Eij_gpu(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::GpuDevice, ct::TensorShape({eij_elems}));
+ void* d_Eij = static_cast(Eij_gpu.data>());
+
+ std::complex alpha = {1.0, 0.0};
+ std::complex beta = {0.0, 0.0};
- // Perform MPI reduction to compute ekb
- info = MPI_Allreduce(eii.data(), ekb.data(), nband, MPI_DOUBLE, MPI_SUM, pv->comm());
+ // 3. Matrix Descriptors Creation
+ cublasMpMatrixDescriptor_t desc_H, desc_Psi, desc_Eij;
+
+ // H descriptor: nlocal x nlocal
+ cublasMpMatrixDescriptorCreate(pv->desc[2],
+ pv->desc[3],
+ pv->desc[4],
+ pv->desc[5],
+ 0,
+ 0,
+ pv->desc[8],
+ CUDA_C_64F,
+ cublas_res.cublasmp_grid,
+ &desc_H);
+
+ // Psi descriptor: nlocal x nband
+ cublasMpMatrixDescriptorCreate(pv->desc_wfc[2],
+ pv->desc_wfc[3],
+ pv->desc_wfc[4],
+ pv->desc_wfc[5],
+ 0,
+ 0,
+ pv->desc_wfc[8],
+ CUDA_C_64F,
+ cublas_res.cublasmp_grid,
+ &desc_Psi);
+
+ // Eij descriptor: MUST use nband x nband physically, to match pv->desc_Eij expectations
+ cublasMpMatrixDescriptorCreate(nband,
+ nband,
+ pv->desc_Eij[4],
+ pv->desc_Eij[5],
+ 0,
+ 0,
+ pv->desc_Eij[8],
+ CUDA_C_64F,
+ cublas_res.cublasmp_grid,
+ &desc_Eij);
+
+ size_t ws_dev = 0, ws_host = 0;
+ void *d_work = nullptr, *h_work = nullptr;
+
+ // 4. GEMM 1: Tmp1 = H * Psi
+ cublasMpGemm_bufferSize(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ CUBLAS_OP_N,
+ pv->desc[2],
+ pv->desc_wfc[3],
+ pv->desc[3],
+ &alpha,
+ d_H,
+ 1,
+ 1,
+ desc_H,
+ d_Psi,
+ 1,
+ 1,
+ desc_Psi,
+ &beta,
+ d_Tmp1,
+ 1,
+ 1,
+ desc_Psi,
+ CUBLAS_COMPUTE_64F,
+ &ws_dev,
+ &ws_host);
+
+ cudaMallocAsync(&d_work, ws_dev, cublas_res.stream);
+ h_work = malloc(ws_host);
+
+ cublasMpGemm(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ CUBLAS_OP_N,
+ pv->desc[2],
+ pv->desc_wfc[3],
+ pv->desc[3],
+ &alpha,
+ d_H,
+ 1,
+ 1,
+ desc_H,
+ d_Psi,
+ 1,
+ 1,
+ desc_Psi,
+ &beta,
+ d_Tmp1,
+ 1,
+ 1,
+ desc_Psi,
+ CUBLAS_COMPUTE_64F,
+ d_work,
+ ws_dev,
+ h_work,
+ ws_host);
+
+ cudaFreeAsync(d_work, cublas_res.stream);
+ free(h_work);
+
+ // 5. GEMM 2: Eij = Psi^H * Tmp1
+ cublasMpGemm_bufferSize(cublas_res.cublasmp_handle,
+ CUBLAS_OP_C,
+ CUBLAS_OP_N,
+ pv->desc_wfc[3],
+ pv->desc_wfc[3],
+ pv->desc_wfc[2],
+ &alpha,
+ d_Psi,
+ 1,
+ 1,
+ desc_Psi,
+ d_Tmp1,
+ 1,
+ 1,
+ desc_Psi,
+ &beta,
+ d_Eij,
+ 1,
+ 1,
+ desc_Eij,
+ CUBLAS_COMPUTE_64F,
+ &ws_dev,
+ &ws_host);
+
+ cudaMallocAsync(&d_work, ws_dev, cublas_res.stream);
+ h_work = malloc(ws_host);
+
+ cublasMpGemm(cublas_res.cublasmp_handle,
+ CUBLAS_OP_C,
+ CUBLAS_OP_N,
+ pv->desc_wfc[3],
+ pv->desc_wfc[3],
+ pv->desc_wfc[2],
+ &alpha,
+ d_Psi,
+ 1,
+ 1,
+ desc_Psi,
+ d_Tmp1,
+ 1,
+ 1,
+ desc_Psi,
+ &beta,
+ d_Eij,
+ 1,
+ 1,
+ desc_Eij,
+ CUBLAS_COMPUTE_64F,
+ d_work,
+ ws_dev,
+ h_work,
+ ws_host);
+
+ cudaFreeAsync(d_work, cublas_res.stream);
+ free(h_work);
+
+ // 6. Extract Diagonal directly on GPU
+ // Prepare a zero-initialized buffer on GPU to store the local parts of the diagonal
+ ct::Tensor eii_gpu(ct::DataType::DT_DOUBLE, ct::DeviceType::GpuDevice, ct::TensorShape({nband}));
+ double* d_eii = static_cast(eii_gpu.data());
+ cudaMemsetAsync(d_eii, 0, nband * sizeof(double), cublas_res.stream);
+
+ // Launch the extraction kernel
+ module_rt::gpu::launch_extract_ekb_kernel(reinterpret_cast(d_Eij),
+ d_eii,
+ pv->desc_Eij[8],
+ pv->nloc,
+ pv->desc_Eij[4],
+ pv->dim0,
+ pv->dim1,
+ pv->coord[0],
+ pv->coord[1],
+ nband,
+ cublas_res.stream);
+
+ // 7. CUDA-aware MPI Reduction
+ // VERY IMPORTANT: We must synchronize the stream before passing the GPU pointer
+ // to MPI, because MPI operations are generally synchronous to the CPU thread.
+ cudaStreamSynchronize(cublas_res.stream);
+
+ double* d_ekb = static_cast(ekb.data());
+
+ // Direct GPU-to-GPU reduction using CUDA-aware MPI
+ MPI_Allreduce(d_eii, d_ekb, nband, MPI_DOUBLE, MPI_SUM, pv->comm());
+
+ // 8. Cleanup
+ cublasMpMatrixDescriptorDestroy(desc_H);
+ cublasMpMatrixDescriptorDestroy(desc_Psi);
+ cublasMpMatrixDescriptorDestroy(desc_Eij);
+#endif // __CUBLASMP
}
template
diff --git a/source/source_lcao/module_rt/band_energy.h b/source/source_lcao/module_rt/band_energy.h
index 93c83ccdb0..107cd749e5 100644
--- a/source/source_lcao/module_rt/band_energy.h
+++ b/source/source_lcao/module_rt/band_energy.h
@@ -8,6 +8,7 @@
#include "source_base/module_container/ATen/core/tensor.h" // ct::Tensor
#include "source_basis/module_ao/parallel_orbitals.h"
+#include "source_lcao/module_rt/kernels/cublasmp_context.h"
#include
@@ -38,7 +39,8 @@ void compute_ekb_tensor(const Parallel_Orbitals* pv,
const ct::Tensor& Htmp,
const ct::Tensor& psi_k,
ct::Tensor& ekb,
- std::ofstream& ofs_running);
+ std::ofstream& ofs_running,
+ CublasMpResources& cublas_res);
template
void compute_ekb_tensor_lapack(const Parallel_Orbitals* pv,
diff --git a/source/source_lcao/module_rt/evolve_elec.cpp b/source/source_lcao/module_rt/evolve_elec.cpp
index 6d6eebc2ff..dd7507974a 100644
--- a/source/source_lcao/module_rt/evolve_elec.cpp
+++ b/source/source_lcao/module_rt/evolve_elec.cpp
@@ -41,6 +41,12 @@ void Evolve_elec::solve_psi(const int& istep,
// Control the print of matrix to running_md.log
const int print_matrix = 0;
+ // Multi-GPU support
+ CublasMpResources cublas_res;
+#ifdef __CUBLASMP
+ init_cublasmp_resources(cublas_res, MPI_COMM_WORLD, para_orb.desc);
+#endif
+
for (int ik = 0; ik < nks; ik++)
{
phm->updateHk(ik);
@@ -171,7 +177,8 @@ void Evolve_elec::solve_psi(const int& istep,
propagator,
ofs_running,
print_matrix,
- use_lapack);
+ use_lapack,
+ cublas_res);
ModuleBase::timer::tick("TD_Efficiency", "host_device_comm");
// Need to distribute global psi back to all processes
@@ -237,6 +244,10 @@ void Evolve_elec::solve_psi(const int& istep,
ModuleBase::timer::tick("TD_Efficiency", "evolve_k");
} // end k
+#ifdef __CUBLASMP
+ finalize_cublasmp_resources(cublas_res);
+#endif
+
ModuleBase::timer::tick("Evolve_elec", "solve_psi");
return;
}
diff --git a/source/source_lcao/module_rt/evolve_elec.h b/source/source_lcao/module_rt/evolve_elec.h
index 5d0a8e5455..3c2aa95cf6 100644
--- a/source/source_lcao/module_rt/evolve_elec.h
+++ b/source/source_lcao/module_rt/evolve_elec.h
@@ -12,6 +12,7 @@
#include "source_esolver/esolver_ks_lcao_tddft.h"
#include "source_lcao/hamilt_lcao.h"
#include "source_lcao/module_rt/gather_mat.h" // MPI gathering and distributing functions
+#include "source_lcao/module_rt/kernels/cublasmp_context.h"
#include "source_psi/psi.h"
//-----------------------------------------------------------
@@ -26,64 +27,101 @@
// Print the shape of a Tensor
inline void print_tensor_shape(const ct::Tensor& tensor, const std::string& name)
{
- std::cout << "Shape of " << name << ": [";
+ GlobalV::ofs_running << "Shape of " << name << ": [";
for (int i = 0; i < tensor.shape().ndim(); ++i)
{
- std::cout << tensor.shape().dim_size(i);
+ GlobalV::ofs_running << tensor.shape().dim_size(i);
if (i < tensor.shape().ndim() - 1)
{
- std::cout << ", ";
+ GlobalV::ofs_running << ", ";
}
}
- std::cout << "]" << std::endl;
+ GlobalV::ofs_running << "]" << std::endl;
}
// Recursive print function
+template
+inline void print_single_element(const T& val, double threshold)
+{
+ double clean_val = (std::abs(val) < threshold) ? 0.0 : static_cast(val);
+ GlobalV::ofs_running << std::fixed << std::setprecision(6) << clean_val;
+}
+inline void print_single_element(const std::complex& val, double threshold)
+{
+ double re = (std::abs(val.real()) < threshold) ? 0.0 : val.real();
+ double im = (std::abs(val.imag()) < threshold) ? 0.0 : val.imag();
+ GlobalV::ofs_running << std::fixed << std::setprecision(6) << "(" << re << "," << im << ")";
+}
+
template
inline void print_tensor_data_recursive(const T* data,
const std::vector& shape,
const std::vector& strides,
int dim,
std::vector& indices,
- const std::string& name)
+ const std::string& name,
+ const double threshold = 1e-10)
{
if (dim == shape.size())
{
- // Recursion base case: print data when reaching the innermost dimension
- std::cout << name;
+ GlobalV::ofs_running << name;
for (size_t i = 0; i < indices.size(); ++i)
{
- std::cout << "[" << indices[i] << "]";
+ GlobalV::ofs_running << "[" << indices[i] << "]";
}
- std::cout << " = " << *data << std::endl;
+ GlobalV::ofs_running << " = ";
+
+ print_single_element(*data, threshold);
+
+ GlobalV::ofs_running << std::endl;
return;
}
- // Recursively process the current dimension
+
for (int64_t i = 0; i < shape[dim]; ++i)
{
indices[dim] = i;
- print_tensor_data_recursive(data + i * strides[dim], shape, strides, dim + 1, indices, name);
+ print_tensor_data_recursive(data + i * strides[dim], shape, strides, dim + 1, indices, name, threshold);
}
}
-// Generic print function
template
inline void print_tensor_data(const ct::Tensor& tensor, const std::string& name)
{
- const std::vector& shape = tensor.shape().dims();
- const std::vector& strides = tensor.shape().strides();
- const T* data = tensor.data();
+ const ct::Tensor* p_tensor = &tensor;
+ ct::Tensor cpu_tensor_buffer;
+
+ if (tensor.device_type() != ct::DeviceType::CpuDevice)
+ {
+ cpu_tensor_buffer = tensor.to_device();
+ p_tensor = &cpu_tensor_buffer;
+ }
+
+ const std::vector& shape = p_tensor->shape().dims();
+ const std::vector& strides = p_tensor->shape().strides();
+
+ const T* data = p_tensor->data();
+
std::vector indices(shape.size(), 0);
print_tensor_data_recursive(data, shape, strides, 0, indices, name);
}
-// Specialization for std::complex
template <>
inline void print_tensor_data>(const ct::Tensor& tensor, const std::string& name)
{
- const std::vector& shape = tensor.shape().dims();
- const std::vector& strides = tensor.shape().strides();
- const std::complex* data = tensor.data>();
+ const ct::Tensor* p_tensor = &tensor;
+ ct::Tensor cpu_tensor_buffer;
+
+ if (tensor.device_type() != ct::DeviceType::CpuDevice)
+ {
+ cpu_tensor_buffer = tensor.to_device();
+ p_tensor = &cpu_tensor_buffer;
+ }
+
+ const std::vector& shape = p_tensor->shape().dims();
+ const std::vector& strides = p_tensor->shape().strides();
+
+ const std::complex* data = p_tensor->data>();
+
std::vector indices(shape.size(), 0);
print_tensor_data_recursive(data, shape, strides, 0, indices, name);
}
diff --git a/source/source_lcao/module_rt/evolve_psi.cpp b/source/source_lcao/module_rt/evolve_psi.cpp
index 5a2f116fcd..ca2458651f 100644
--- a/source/source_lcao/module_rt/evolve_psi.cpp
+++ b/source/source_lcao/module_rt/evolve_psi.cpp
@@ -129,7 +129,8 @@ void evolve_psi_tensor(const int nband,
int propagator,
std::ofstream& ofs_running,
const int print_matrix,
- const bool use_lapack)
+ const bool use_lapack,
+ CublasMpResources& cublas_res)
{
ModuleBase::TITLE("module_rt", "evolve_psi_tensor");
time_t time_start = time(nullptr);
@@ -221,7 +222,16 @@ void evolve_psi_tensor(const int nband,
{
if (!use_lapack)
{
- half_Hmatrix_tensor(pv, nband, nlocal, Htmp, Stmp, H_laststep, S_laststep, ofs_running, print_matrix);
+ half_Hmatrix_tensor(pv,
+ nband,
+ nlocal,
+ Htmp,
+ Stmp,
+ H_laststep,
+ S_laststep,
+ ofs_running,
+ print_matrix,
+ cublas_res);
}
else if (myid == root_proc)
{
@@ -249,12 +259,13 @@ void evolve_psi_tensor(const int nband,
U_operator,
ofs_running,
print_matrix,
- use_lapack);
+ use_lapack,
+ cublas_res);
// (3) Apply U_operator (psi_k = U * psi_last)
if (!use_lapack)
{
- upsi_tensor(pv, nband, nlocal, U_operator, psi_k_laststep, psi_k, ofs_running, print_matrix);
+ upsi_tensor(pv, nband, nlocal, U_operator, psi_k_laststep, psi_k, ofs_running, print_matrix, cublas_res);
}
else if (myid == root_proc)
{
@@ -264,7 +275,7 @@ void evolve_psi_tensor(const int nband,
// (4) Normalize psi_k
if (!use_lapack)
{
- norm_psi_tensor(pv, nband, nlocal, Stmp, psi_k, ofs_running, print_matrix);
+ norm_psi_tensor(pv, nband, nlocal, Stmp, psi_k, ofs_running, print_matrix, cublas_res);
}
else if (myid == root_proc)
{
@@ -287,7 +298,7 @@ void evolve_psi_tensor(const int nband,
if (!use_lapack)
{
- compute_ekb_tensor(pv, nband, nlocal, Hold, psi_k, ekb, ofs_running);
+ compute_ekb_tensor(pv, nband, nlocal, Hold, psi_k, ekb, ofs_running, cublas_res);
}
else if (myid == root_proc)
{
@@ -323,7 +334,8 @@ template void evolve_psi_tensor(const int nband,
int propagator,
std::ofstream& ofs_running,
const int print_matrix,
- const bool use_lapack);
+ const bool use_lapack,
+ CublasMpResources& cublas_res);
#if ((defined __CUDA) /* || (defined __ROCM) */)
template void evolve_psi_tensor(const int nband,
@@ -338,7 +350,8 @@ template void evolve_psi_tensor(const int nband,
int propagator,
std::ofstream& ofs_running,
const int print_matrix,
- const bool use_lapack);
+ const bool use_lapack,
+ CublasMpResources& cublas_res);
#endif // __CUDA
} // namespace module_rt
diff --git a/source/source_lcao/module_rt/evolve_psi.h b/source/source_lcao/module_rt/evolve_psi.h
index 413b115a0f..34a29d6881 100644
--- a/source/source_lcao/module_rt/evolve_psi.h
+++ b/source/source_lcao/module_rt/evolve_psi.h
@@ -10,6 +10,8 @@
#include "source_base/module_container/ATen/core/tensor_map.h" // TensorMap
#include "source_basis/module_ao/parallel_orbitals.h"
#include "source_lcao/hamilt_lcao.h"
+#include "source_lcao/module_rt/evolve_elec.h"
+#include "source_lcao/module_rt/kernels/cublasmp_context.h"
namespace module_rt
{
@@ -39,7 +41,8 @@ void evolve_psi_tensor(const int nband,
int propagator,
std::ofstream& ofs_running,
const int print_matrix,
- const bool use_lapack);
+ const bool use_lapack,
+ CublasMpResources& cublas_res);
} // namespace module_rt
#endif
\ No newline at end of file
diff --git a/source/source_lcao/module_rt/kernels/cublasmp_context.h b/source/source_lcao/module_rt/kernels/cublasmp_context.h
new file mode 100644
index 0000000000..897f1964c2
--- /dev/null
+++ b/source/source_lcao/module_rt/kernels/cublasmp_context.h
@@ -0,0 +1,167 @@
+#ifndef CUBLASMP_CONTEXT_H
+#define CUBLASMP_CONTEXT_H
+
+#ifdef __MPI
+#include
+#endif
+
+#ifdef __CUDA
+#include
+#endif
+
+#ifdef __CUBLASMP
+#include "source_base/global_variable.h"
+#include "source_base/module_device/device.h"
+
+#include
+#include
+#include
+#include
+
+extern "C"
+{
+#include "source_hsolver/module_genelpa/Cblacs.h"
+}
+
+#define LOG_DEBUG(msg) \
+ do \
+ { \
+ if (g_EnableDebugLog) \
+ { \
+ std::cerr << "[DEBUG] " << msg << " (at " << __func__ << ")" << std::endl; \
+ } \
+ } while (0)
+#endif // __CUBLASMP
+
+// The struct is ALWAYS available.
+struct CublasMpResources
+{
+ bool is_initialized = false;
+
+#ifdef __MPI
+ MPI_Comm mpi_comm = MPI_COMM_NULL;
+#endif
+
+#ifdef __CUDA
+ cudaStream_t stream = nullptr;
+#endif
+
+#ifdef __CUBLASMP
+ ncclComm_t nccl_comm = nullptr;
+
+ cublasMpHandle_t cublasmp_handle = nullptr;
+ cublasMpGrid_t cublasmp_grid = nullptr;
+
+ cusolverMpHandle_t cusolvermp_handle = nullptr;
+ cusolverMpGrid_t cusolvermp_grid = nullptr;
+#endif
+};
+
+// API functions are only visible when cuBLASMp is enabled.
+#ifdef __CUBLASMP
+
+inline void init_cublasmp_resources(CublasMpResources& res, MPI_Comm mpi_comm, const int* desc)
+{
+ if (res.is_initialized)
+ {
+ return;
+ }
+
+ res.mpi_comm = mpi_comm;
+ MPI_Barrier(res.mpi_comm);
+
+ // 1. Get BLACS topology info
+ int cblacs_ctxt = desc[1];
+ int nprows, npcols, myprow, mypcol;
+ Cblacs_gridinfo(cblacs_ctxt, &nprows, &npcols, &myprow, &mypcol);
+
+ GlobalV::ofs_running << "nprows = " << nprows << std::endl;
+ GlobalV::ofs_running << "npcols = " << npcols << std::endl;
+ GlobalV::ofs_running << "myprow = " << myprow << std::endl;
+ GlobalV::ofs_running << "mypcol = " << mypcol << std::endl;
+ GlobalV::ofs_running << "device = " << base_device::DeviceContext::instance().get_device_id() << std::endl;
+
+ int rank, size;
+ MPI_Comm_rank(res.mpi_comm, &rank);
+ MPI_Comm_size(res.mpi_comm, &size);
+
+ int device_id = base_device::DeviceContext::instance().get_device_id();
+ cudaSetDevice(device_id);
+ cudaStreamCreate(&res.stream);
+
+ // 2. Initialize NCCL communicator
+ ncclUniqueId id;
+ if (rank == 0)
+ {
+ ncclGetUniqueId(&id);
+ }
+ // Broadcast the unique NCCL ID to all ranks
+ MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, res.mpi_comm);
+ // Initialize NCCL with the generated ID
+ ncclCommInitRank(&res.nccl_comm, size, id, rank);
+
+ // 3. Initialize cuBLASMp specific resources
+ cublasMpCreate(&res.cublasmp_handle, res.stream);
+ cublasMpGridCreate(nprows, npcols, CUBLASMP_GRID_LAYOUT_ROW_MAJOR, res.nccl_comm, &res.cublasmp_grid);
+
+ // 4. Initialize cuSOLVERMp specific resources
+ cusolverMpCreate(&res.cusolvermp_handle, device_id, res.stream);
+ cusolverMpCreateDeviceGrid(res.cusolvermp_handle,
+ &res.cusolvermp_grid,
+ res.nccl_comm,
+ nprows,
+ npcols,
+ CUSOLVERMP_GRID_MAPPING_ROW_MAJOR);
+
+ res.is_initialized = true;
+}
+
+inline void finalize_cublasmp_resources(CublasMpResources& res)
+{
+ if (!res.is_initialized)
+ {
+ return;
+ }
+
+ if (res.stream)
+ {
+ cudaStreamSynchronize(res.stream);
+ }
+
+ // Destroy cuBLASMp resources
+ if (res.cublasmp_grid)
+ {
+ cublasMpGridDestroy(res.cublasmp_grid);
+ }
+ if (res.cublasmp_handle)
+ {
+ cublasMpDestroy(res.cublasmp_handle);
+ }
+
+ // Destroy cuSOLVERMp resources
+ if (res.cusolvermp_grid)
+ {
+ cusolverMpDestroyGrid(res.cusolvermp_grid);
+ }
+ if (res.cusolvermp_handle)
+ {
+ cusolverMpDestroy(res.cusolvermp_handle);
+ }
+
+ // Destroy NCCL communicator
+ if (res.nccl_comm)
+ {
+ ncclCommDestroy(res.nccl_comm);
+ }
+
+ if (res.stream)
+ {
+ cudaStreamDestroy(res.stream);
+ }
+
+ res.is_initialized = false;
+}
+
+#endif // __CUBLASMP
+
+#endif // CUBLASMP_CONTEXT_H
diff --git a/source/source_lcao/module_rt/kernels/cuda/band_energy_kernel.cu b/source/source_lcao/module_rt/kernels/cuda/band_energy_kernel.cu
new file mode 100644
index 0000000000..74e55610a0
--- /dev/null
+++ b/source/source_lcao/module_rt/kernels/cuda/band_energy_kernel.cu
@@ -0,0 +1,88 @@
+#include "band_energy_kernel.cuh"
+
+#ifdef __CUBLASMP
+namespace module_rt
+{
+namespace gpu
+{
+
+// Device function for global index mapping
+__device__ inline int get_global_index_dev(int local_idx, int block_size, int num_procs, int proc_coord)
+{
+ return (local_idx / block_size) * (num_procs * block_size) + proc_coord * block_size + (local_idx % block_size);
+}
+
+// Kernel to extract the real part of the diagonal elements of Eij
+__global__ void extract_ekb_kernel(const cuDoubleComplex* d_Eij,
+ double* d_eii,
+ int lld,
+ int local_elems,
+ int nb,
+ int dim0,
+ int dim1,
+ int my_prow,
+ int my_pcol,
+ int nband)
+{
+ int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ // Iterate over the allocated local buffer
+ if (idx >= local_elems)
+ {
+ return;
+ }
+
+ // Column-major indexing using the formal Leading Dimension (LLD)
+ int i = idx % lld;
+ int j = idx / lld;
+
+ int grow = get_global_index_dev(i, nb, dim0, my_prow);
+ int gcol = get_global_index_dev(j, nb, dim1, my_pcol);
+
+ // Filter out invalid blocks
+ if (grow >= nband || gcol >= nband)
+ {
+ return;
+ }
+
+ // Extract the diagonal elements
+ if (grow == gcol)
+ {
+ d_eii[grow] = cuCreal(d_Eij[idx]);
+ }
+}
+
+// Wrapper implementation
+void launch_extract_ekb_kernel(const cuDoubleComplex* d_Eij,
+ double* d_eii,
+ int lld,
+ int local_elems,
+ int nb,
+ int dim0,
+ int dim1,
+ int my_prow,
+ int my_pcol,
+ int nband,
+ cudaStream_t stream)
+{
+ if (local_elems > 0)
+ {
+ int threads_per_block = 256;
+ int blocks_per_grid = (local_elems + threads_per_block - 1) / threads_per_block;
+
+ extract_ekb_kernel<<>>(d_Eij,
+ d_eii,
+ lld,
+ local_elems,
+ nb,
+ dim0,
+ dim1,
+ my_prow,
+ my_pcol,
+ nband);
+ }
+}
+
+} // namespace gpu
+} // namespace module_rt
+#endif // __CUBLASMP
diff --git a/source/source_lcao/module_rt/kernels/cuda/band_energy_kernel.cuh b/source/source_lcao/module_rt/kernels/cuda/band_energy_kernel.cuh
new file mode 100644
index 0000000000..2c0198b8f5
--- /dev/null
+++ b/source/source_lcao/module_rt/kernels/cuda/band_energy_kernel.cuh
@@ -0,0 +1,31 @@
+#ifndef BAND_ENERGY_KERNEL_CUH
+#define BAND_ENERGY_KERNEL_CUH
+
+#include
+
+#ifdef __CUBLASMP
+#include
+
+namespace module_rt
+{
+namespace gpu
+{
+
+// Standard C++ wrapper to launch the diagonal extraction kernel
+void launch_extract_ekb_kernel(const cuDoubleComplex* d_Eij,
+ double* d_eii,
+ int lld,
+ int local_elems,
+ int nb,
+ int dim0,
+ int dim1,
+ int my_prow,
+ int my_pcol,
+ int nband,
+ cudaStream_t stream);
+
+} // namespace gpu
+} // namespace module_rt
+#endif // __CUBLASMP
+
+#endif // BAND_ENERGY_KERNEL_CUH
diff --git a/source/source_lcao/module_rt/kernels/cuda/norm_psi_kernel.cu b/source/source_lcao/module_rt/kernels/cuda/norm_psi_kernel.cu
new file mode 100644
index 0000000000..4008df990a
--- /dev/null
+++ b/source/source_lcao/module_rt/kernels/cuda/norm_psi_kernel.cu
@@ -0,0 +1,94 @@
+#include "norm_psi_kernel.cuh"
+
+#include
+
+#ifdef __CUBLASMP
+namespace module_rt
+{
+namespace gpu
+{
+
+// Device function for global index mapping
+__device__ inline int get_global_index_dev(int local_idx, int block_size, int num_procs, int proc_coord)
+{
+ return (local_idx / block_size) * (num_procs * block_size) + proc_coord * block_size + (local_idx % block_size);
+}
+
+// CUDA kernel to normalize the Cij matrix directly on the GPU
+__global__ void normalize_cij_kernel(cuDoubleComplex* d_Cij,
+ int lld,
+ int local_elems,
+ int nb,
+ int dim0,
+ int dim1,
+ int my_prow,
+ int my_pcol,
+ int nband)
+{
+ int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ // We iterate over the entire allocated buffer (local_elems)
+ if (idx >= local_elems)
+ {
+ return;
+ }
+
+ // Column-major indexing using the formal Leading Dimension (LLD)
+ int i = idx % lld;
+ int j = idx / lld;
+
+ int grow = get_global_index_dev(i, nb, dim0, my_prow);
+ int gcol = get_global_index_dev(j, nb, dim1, my_pcol);
+
+ // Filter out the "empty" spaces that are outside the nband x nband logic
+ if (grow >= nband || gcol >= nband)
+ {
+ return;
+ }
+
+ if (grow == gcol)
+ {
+ double val = cuCreal(d_Cij[idx]);
+ if (val < 1e-12)
+ val = 1e-12;
+ d_Cij[idx] = make_cuDoubleComplex(1.0 / sqrt(val), 0.0);
+ }
+ else
+ {
+ d_Cij[idx] = make_cuDoubleComplex(0.0, 0.0);
+ }
+}
+
+// Wrapper function implementation
+void launch_normalize_cij_kernel(cuDoubleComplex* d_Cij,
+ int nrow,
+ int ncol,
+ int nb,
+ int dim0,
+ int dim1,
+ int my_prow,
+ int my_pcol,
+ int nband,
+ cudaStream_t stream)
+{
+ int total_elems = nrow * ncol;
+ if (total_elems > 0)
+ {
+ int threads_per_block = 256;
+ int blocks_per_grid = (total_elems + threads_per_block - 1) / threads_per_block;
+
+ normalize_cij_kernel<<>>(d_Cij,
+ nrow,
+ ncol,
+ nb,
+ dim0,
+ dim1,
+ my_prow,
+ my_pcol,
+ nband);
+ }
+}
+
+} // namespace gpu
+} // namespace module_rt
+#endif // __CUBLASMP
diff --git a/source/source_lcao/module_rt/kernels/cuda/norm_psi_kernel.cuh b/source/source_lcao/module_rt/kernels/cuda/norm_psi_kernel.cuh
new file mode 100644
index 0000000000..af2fb0dd6e
--- /dev/null
+++ b/source/source_lcao/module_rt/kernels/cuda/norm_psi_kernel.cuh
@@ -0,0 +1,30 @@
+#ifndef NORM_PSI_KERNEL_CUH
+#define NORM_PSI_KERNEL_CUH
+
+#include
+
+#ifdef __CUBLASMP
+#include
+
+namespace module_rt
+{
+namespace gpu
+{
+
+// Standard C++ wrapper to launch the normalization kernel
+void launch_normalize_cij_kernel(cuDoubleComplex* d_Cij,
+ int nrow,
+ int ncol,
+ int nb,
+ int dim0,
+ int dim1,
+ int my_prow,
+ int my_pcol,
+ int nband,
+ cudaStream_t stream);
+
+} // namespace gpu
+} // namespace module_rt
+#endif // __CUBLASMP
+
+#endif // NORM_PSI_KERNEL_CUH
diff --git a/source/source_lcao/module_rt/middle_hamilt.cpp b/source/source_lcao/module_rt/middle_hamilt.cpp
index 10f432fbe9..a7b875a270 100644
--- a/source/source_lcao/module_rt/middle_hamilt.cpp
+++ b/source/source_lcao/module_rt/middle_hamilt.cpp
@@ -1,9 +1,12 @@
#include "middle_hamilt.h"
+#include "source_base/global_variable.h"
#include "source_base/module_container/ATen/kernels/blas.h"
#include "source_base/module_device/memory_op.h" // memory operations
#include "source_base/module_external/scalapack_connector.h"
+#include "source_base/timer.h"
+#include
#include
#include
@@ -80,88 +83,125 @@ void half_Hmatrix_tensor(const Parallel_Orbitals* pv,
const ct::Tensor& H_laststep,
const ct::Tensor& S_laststep,
std::ofstream& ofs_running,
- const int print_matrix)
+ const int print_matrix,
+ CublasMpResources& cublas_res)
{
- if (print_matrix)
+#ifdef __CUBLASMP
+ // 1. Validate resources and ensure the grid is properly initialized
+ if (!cublas_res.is_initialized || cublas_res.cublasmp_grid == nullptr)
{
- ofs_running << std::setprecision(10);
- ofs_running << std::endl;
- ofs_running << " H(t+dt) :" << std::endl;
- for (int i = 0; i < pv->nrow; i++)
- {
- const int in = i * pv->ncol;
- for (int j = 0; j < pv->ncol; j++)
- {
- ofs_running << Htmp.data>()[in + j].real() << "+"
- << Htmp.data>()[in + j].imag() << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
- ofs_running << std::endl;
- ofs_running << " H(t):" << std::endl;
- for (int i = 0; i < pv->nrow; i++)
- {
- const int in = i * pv->ncol;
- for (int j = 0; j < pv->ncol; j++)
- {
- ofs_running << H_laststep.data>()[in + j].real() << "+"
- << H_laststep.data>()[in + j].imag() << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
+ return;
}
+ assert(Htmp.device_type() == ct::DeviceType::GpuDevice);
+ assert(Stmp.device_type() == ct::DeviceType::GpuDevice);
+ assert(H_laststep.device_type() == ct::DeviceType::GpuDevice);
+ assert(S_laststep.device_type() == ct::DeviceType::GpuDevice);
+
+ // 2. Extract device pointers
+ void* d_Htmp = static_cast(Htmp.data>());
+ void* d_Stmp = static_cast(Stmp.data>());
+ void* d_H_last = static_cast(const_cast*>(H_laststep.data>()));
+ void* d_S_last = static_cast(const_cast*>(S_laststep.data>()));
+
+ int64_t m_global = pv->desc[2];
+ int64_t n_global = pv->desc[3];
+ int64_t mb = pv->desc[4];
+ int64_t nb = pv->desc[5];
+ int64_t rsrc = pv->desc[6];
+ int64_t csrc = pv->desc[7];
+ int64_t lld = pv->desc[8];
+
+ // 3. Create matrix descriptor
+ cublasMpMatrixDescriptor_t desc_mat;
+ cublasMpMatrixDescriptorCreate(m_global,
+ n_global,
+ mb,
+ nb,
+ rsrc,
+ csrc,
+ lld,
+ CUDA_C_64F,
+ cublas_res.cublasmp_grid,
+ &desc_mat);
+
std::complex alpha = {0.5, 0.0};
std::complex beta = {0.5, 0.0};
- // Perform the operation Htmp = alpha * H_laststep + beta * Htmp
- ScalapackConnector::geadd('N',
- nlocal,
- nlocal,
- alpha,
- H_laststep.data>(),
- 1,
- 1,
- pv->desc,
- beta,
- Htmp.data>(),
- 1,
- 1,
- pv->desc);
-
- // Perform the operation Stmp = alpha * S_laststep + beta * Stmp
- ScalapackConnector::geadd('N',
- nlocal,
- nlocal,
- alpha,
- S_laststep.data>(),
- 1,
- 1,
- pv->desc,
- beta,
- Stmp.data>(),
- 1,
- 1,
- pv->desc);
+ size_t ws_size_dev = 0;
+ size_t ws_size_host = 0;
- if (print_matrix)
- {
- ofs_running << std::endl;
- ofs_running << " H (t+dt/2) :" << std::endl;
- for (int i = 0; i < pv->nrow; i++)
- {
- const int in = i * pv->ncol;
- for (int j = 0; j < pv->ncol; j++)
- {
- ofs_running << Htmp.data>()[in + j].real() << "+"
- << Htmp.data>()[in + j].imag() << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
- }
+ // 4. Query workspace size
+ cublasMpGeadd_bufferSize(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ m_global,
+ n_global,
+ &alpha,
+ d_H_last,
+ 1,
+ 1,
+ desc_mat,
+ &beta,
+ d_Htmp,
+ 1,
+ 1,
+ desc_mat,
+ &ws_size_dev,
+ &ws_size_host);
+
+ void* d_work = nullptr;
+ void* h_work = nullptr;
+
+ cudaMallocAsync(&d_work, ws_size_dev, cublas_res.stream);
+ h_work = malloc(ws_size_host);
+
+ // 5. Compute Htmp = 0.5 * H_last + 0.5 * Htmp
+ cublasMpGeadd(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ m_global,
+ n_global,
+ &alpha,
+ d_H_last,
+ 1,
+ 1,
+ desc_mat,
+ &beta,
+ d_Htmp,
+ 1,
+ 1,
+ desc_mat,
+ d_work,
+ ws_size_dev,
+ h_work,
+ ws_size_host);
+
+ // 6. Compute Stmp = 0.5 * S_last + 0.5 * Stmp
+ cublasMpGeadd(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ m_global,
+ n_global,
+ &alpha,
+ d_S_last,
+ 1,
+ 1,
+ desc_mat,
+ &beta,
+ d_Stmp,
+ 1,
+ 1,
+ desc_mat,
+ d_work,
+ ws_size_dev,
+ h_work,
+ ws_size_host);
+
+ // 7. Synchronize stream and release resources
+ cudaStreamSynchronize(cublas_res.stream);
+
+ cublasMpMatrixDescriptorDestroy(desc_mat);
+ cudaFreeAsync(d_work, cublas_res.stream);
+ free(h_work);
+#endif // __CUBLASMP
}
template
diff --git a/source/source_lcao/module_rt/middle_hamilt.h b/source/source_lcao/module_rt/middle_hamilt.h
index e505185b00..5b59154e4b 100644
--- a/source/source_lcao/module_rt/middle_hamilt.h
+++ b/source/source_lcao/module_rt/middle_hamilt.h
@@ -8,6 +8,7 @@
#include "source_base/module_container/ATen/core/tensor.h" // ct::Tensor
#include "source_basis/module_ao/parallel_orbitals.h"
+#include "source_lcao/module_rt/kernels/cublasmp_context.h"
#include
@@ -43,7 +44,8 @@ void half_Hmatrix_tensor(const Parallel_Orbitals* pv,
const ct::Tensor& H_laststep,
const ct::Tensor& S_laststep,
std::ofstream& ofs_running,
- const int print_matrix);
+ const int print_matrix,
+ CublasMpResources& cublas_res);
template
void half_Hmatrix_tensor_lapack(const Parallel_Orbitals* pv,
diff --git a/source/source_lcao/module_rt/norm_psi.cpp b/source/source_lcao/module_rt/norm_psi.cpp
index f5a9c6c8b4..40884b02df 100644
--- a/source/source_lcao/module_rt/norm_psi.cpp
+++ b/source/source_lcao/module_rt/norm_psi.cpp
@@ -4,6 +4,11 @@
#include "source_base/module_container/ATen/kernels/blas.h"
#include "source_base/module_external/blas_connector.h"
#include "source_base/module_external/scalapack_connector.h"
+#include "source_base/timer.h"
+
+#ifdef __CUBLASMP
+#include "kernels/cuda/norm_psi_kernel.cuh"
+#endif
#include
#include
@@ -231,207 +236,264 @@ void norm_psi_tensor(const Parallel_Orbitals* pv,
const ct::Tensor& Stmp,
ct::Tensor& psi_k,
std::ofstream& ofs_running,
- const int print_matrix)
+ const int print_matrix,
+ CublasMpResources& cublas_res)
{
- assert(pv->nloc_wfc > 0 && pv->nloc > 0);
-
- // Create Tensor objects for temporary data
- ct::Tensor tmp1(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({pv->nloc_wfc}));
- tmp1.zero();
-
- ct::Tensor Cij(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({pv->nloc}));
- Cij.zero();
-
- // Perform matrix multiplication: tmp1 = Stmp * psi_k
- ScalapackConnector::gemm('N',
- 'N',
- nlocal,
- nband,
- nlocal,
- 1.0,
- Stmp.data>(),
- 1,
- 1,
- pv->desc,
- psi_k.data>(),
- 1,
- 1,
- pv->desc_wfc,
- 0.0,
- tmp1.data>(),
- 1,
- 1,
- pv->desc_wfc);
-
- // Perform matrix multiplication: Cij = psi_k^dagger * tmp1
- ScalapackConnector::gemm('C',
- 'N',
- nband,
- nband,
- nlocal,
- 1.0,
- psi_k.data>(),
- 1,
- 1,
- pv->desc_wfc,
- tmp1.data>(),
- 1,
- 1,
- pv->desc_wfc,
- 0.0,
- Cij.data>(),
- 1,
- 1,
- pv->desc_Eij);
-
- if (print_matrix)
+#ifdef __CUBLASMP
+ if (!cublas_res.is_initialized || cublas_res.cublasmp_grid == nullptr)
{
- ofs_running << "original Cij :" << std::endl;
- for (int i = 0; i < pv->ncol; i++)
- {
- const int in = i * pv->ncol;
- for (int j = 0; j < pv->nrow; j++)
- {
- double aa = Cij.data>()[in + j].real();
- double bb = Cij.data>()[in + j].imag();
- if (std::abs(aa) < 1e-8)
- {
- aa = 0.0;
- }
- if (std::abs(bb) < 1e-8)
- {
- bb = 0.0;
- }
- ofs_running << aa << "+" << bb << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
+ return;
}
- int naroc[2] = {0, 0}; // maximum number of row or column
+ void* d_S = static_cast(const_cast*>(Stmp.data>()));
+ void* d_Psi = static_cast(psi_k.data>());
+ int64_t psi_elems = psi_k.NumElements();
- for (int iprow = 0; iprow < pv->dim0; ++iprow)
- {
- for (int ipcol = 0; ipcol < pv->dim1; ++ipcol)
- {
- if (iprow == pv->coord[0] && ipcol == pv->coord[1])
- {
- naroc[0] = pv->nrow;
- naroc[1] = pv->ncol;
- for (int j = 0; j < naroc[1]; ++j)
- {
- int igcol = globalIndex(j, pv->nb, pv->dim1, ipcol);
- if (igcol >= nband)
- {
- continue;
- }
- for (int i = 0; i < naroc[0]; ++i)
- {
- int igrow = globalIndex(i, pv->nb, pv->dim0, iprow);
- if (igrow >= nband)
- {
- continue;
- }
- if (igcol == igrow)
- {
- Cij.data>()[j * naroc[0] + i]
- = {1.0 / sqrt(Cij.data>()[j * naroc[0] + i].real()), 0.0};
- }
- else
- {
- Cij.data>()[j * naroc[0] + i] = {0.0, 0.0};
- }
- }
- }
- }
- } // loop ipcol
- } // loop iprow
+ ct::Tensor Tmp1_gpu(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::GpuDevice, ct::TensorShape({psi_elems}));
+ void* d_Tmp1 = static_cast(Tmp1_gpu.data>());
- // Copy psi_k to tmp1 (using deep copy)
- // tmp1.CopyFrom(psi_k); // Does not work because this will cause tmp1 and psi_k to share the same data
- tmp1 = psi_k; // operator= overload for Tensor class
+ int64_t cij_elems = pv->nloc;
+ ct::Tensor Cij_gpu(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::GpuDevice, ct::TensorShape({cij_elems}));
+ void* d_Cij = static_cast(Cij_gpu.data>());
- // Perform matrix multiplication: psi_k = tmp1 * Cij
- ScalapackConnector::gemm('N',
- 'N',
- nlocal,
- nband,
- nband,
- 1.0,
- tmp1.data>(),
- 1,
- 1,
- pv->desc_wfc,
- Cij.data>(),
- 1,
- 1,
- pv->desc_Eij,
- 0.0,
- psi_k.data>(),
- 1,
- 1,
- pv->desc_wfc);
+ cudaMemsetAsync(d_Cij, 0, cij_elems * sizeof(std::complex), cublas_res.stream);
- if (print_matrix)
- {
- ofs_running << " Cij:" << std::endl;
- for (int i = 0; i < pv->ncol; i++)
- {
- const int in = i * pv->ncol;
- for (int j = 0; j < pv->nrow; j++)
- {
- ofs_running << Cij.data>()[in + j].real() << "+"
- << Cij.data>()[in + j].imag() << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
- ofs_running << std::endl;
- ofs_running << " psi_k:" << std::endl;
- for (int i = 0; i < pv->ncol_bands; i++)
- {
- const int in = i * pv->ncol;
- for (int j = 0; j < pv->ncol; j++)
- {
- double aa = psi_k.data>()[in + j].real();
- double bb = psi_k.data>()[in + j].imag();
- if (std::abs(aa) < 1e-8)
- {
- aa = 0.0;
- }
- if (std::abs(bb) < 1e-8)
- {
- bb = 0.0;
- }
- ofs_running << aa << "+" << bb << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
- ofs_running << " psi_k before normalization:" << std::endl;
- for (int i = 0; i < pv->ncol_bands; i++)
- {
- const int in = i * pv->ncol;
- for (int j = 0; j < pv->ncol; j++)
- {
- double aa = tmp1.data>()[in + j].real();
- double bb = tmp1.data>()[in + j].imag();
- if (std::abs(aa) < 1e-8)
- {
- aa = 0.0;
- }
- if (std::abs(bb) < 1e-8)
- {
- bb = 0.0;
- }
- ofs_running << aa << "+" << bb << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
- ofs_running << std::endl;
- }
+ std::complex alpha = {1.0, 0.0};
+ std::complex beta = {0.0, 0.0};
+
+ cublasMpMatrixDescriptor_t desc_S, desc_Psi, desc_Cij;
+
+ cublasMpMatrixDescriptorCreate(nlocal,
+ nlocal,
+ pv->desc[4],
+ pv->desc[5],
+ 0,
+ 0,
+ pv->desc[8],
+ CUDA_C_64F,
+ cublas_res.cublasmp_grid,
+ &desc_S);
+
+ cublasMpMatrixDescriptorCreate(nlocal,
+ nband,
+ pv->desc_wfc[4],
+ pv->desc_wfc[5],
+ 0,
+ 0,
+ pv->desc_wfc[8],
+ CUDA_C_64F,
+ cublas_res.cublasmp_grid,
+ &desc_Psi);
+
+ cublasMpMatrixDescriptorCreate(nband,
+ nband,
+ pv->desc_Eij[4],
+ pv->desc_Eij[5],
+ 0,
+ 0,
+ pv->desc_Eij[8],
+ CUDA_C_64F,
+ cublas_res.cublasmp_grid,
+ &desc_Cij);
+
+ size_t ws_dev = 0, ws_host = 0;
+ void *d_work = nullptr, *h_work = nullptr;
+
+ // GEMM 1: S * Psi -> Tmp1
+ cublasMpGemm_bufferSize(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ CUBLAS_OP_N,
+ nlocal,
+ nband,
+ nlocal,
+ &alpha,
+ d_S,
+ 1,
+ 1,
+ desc_S,
+ d_Psi,
+ 1,
+ 1,
+ desc_Psi,
+ &beta,
+ d_Tmp1,
+ 1,
+ 1,
+ desc_Psi,
+ CUBLAS_COMPUTE_64F,
+ &ws_dev,
+ &ws_host);
+
+ cudaMallocAsync(&d_work, ws_dev, cublas_res.stream);
+ h_work = malloc(ws_host);
+
+ cublasMpGemm(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ CUBLAS_OP_N,
+ nlocal,
+ nband,
+ nlocal,
+ &alpha,
+ d_S,
+ 1,
+ 1,
+ desc_S,
+ d_Psi,
+ 1,
+ 1,
+ desc_Psi,
+ &beta,
+ d_Tmp1,
+ 1,
+ 1,
+ desc_Psi,
+ CUBLAS_COMPUTE_64F,
+ d_work,
+ ws_dev,
+ h_work,
+ ws_host);
+
+ cudaFreeAsync(d_work, cublas_res.stream);
+ free(h_work);
+
+ // GEMM 2: Psi^H * Tmp1 -> Cij
+ cublasMpGemm_bufferSize(cublas_res.cublasmp_handle,
+ CUBLAS_OP_C,
+ CUBLAS_OP_N,
+ nband,
+ nband,
+ nlocal,
+ &alpha,
+ d_Psi,
+ 1,
+ 1,
+ desc_Psi,
+ d_Tmp1,
+ 1,
+ 1,
+ desc_Psi,
+ &beta,
+ d_Cij,
+ 1,
+ 1,
+ desc_Cij,
+ CUBLAS_COMPUTE_64F,
+ &ws_dev,
+ &ws_host);
+
+ cudaMallocAsync(&d_work, ws_dev, cublas_res.stream);
+ h_work = malloc(ws_host);
+
+ cublasMpGemm(cublas_res.cublasmp_handle,
+ CUBLAS_OP_C,
+ CUBLAS_OP_N,
+ nband,
+ nband,
+ nlocal,
+ &alpha,
+ d_Psi,
+ 1,
+ 1,
+ desc_Psi,
+ d_Tmp1,
+ 1,
+ 1,
+ desc_Psi,
+ &beta,
+ d_Cij,
+ 1,
+ 1,
+ desc_Cij,
+ CUBLAS_COMPUTE_64F,
+ d_work,
+ ws_dev,
+ h_work,
+ ws_host);
+
+ cudaFreeAsync(d_work, cublas_res.stream);
+ free(h_work);
+
+ // Launch GPU In-place Normalization using the C++ wrapper
+ module_rt::gpu::launch_normalize_cij_kernel(reinterpret_cast(d_Cij),
+ pv->desc_Eij[8],
+ pv->nloc,
+ pv->desc_Eij[4],
+ pv->dim0,
+ pv->dim1,
+ pv->coord[0],
+ pv->coord[1],
+ nband,
+ cublas_res.stream);
+
+ // GEMM 3: Tmp1 * Cij -> Psi
+ cudaMemcpyAsync(d_Tmp1,
+ d_Psi,
+ psi_elems * sizeof(std::complex),
+ cudaMemcpyDeviceToDevice,
+ cublas_res.stream);
+
+ cublasMpGemm_bufferSize(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ CUBLAS_OP_N,
+ nlocal,
+ nband,
+ nband,
+ &alpha,
+ d_Tmp1,
+ 1,
+ 1,
+ desc_Psi,
+ d_Cij,
+ 1,
+ 1,
+ desc_Cij,
+ &beta,
+ d_Psi,
+ 1,
+ 1,
+ desc_Psi,
+ CUBLAS_COMPUTE_64F,
+ &ws_dev,
+ &ws_host);
+
+ cudaMallocAsync(&d_work, ws_dev, cublas_res.stream);
+ h_work = malloc(ws_host);
+
+ cublasMpGemm(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ CUBLAS_OP_N,
+ nlocal,
+ nband,
+ nband,
+ &alpha,
+ d_Tmp1,
+ 1,
+ 1,
+ desc_Psi,
+ d_Cij,
+ 1,
+ 1,
+ desc_Cij,
+ &beta,
+ d_Psi,
+ 1,
+ 1,
+ desc_Psi,
+ CUBLAS_COMPUTE_64F,
+ d_work,
+ ws_dev,
+ h_work,
+ ws_host);
+
+ cudaStreamSynchronize(cublas_res.stream);
+
+ cublasMpMatrixDescriptorDestroy(desc_S);
+ cublasMpMatrixDescriptorDestroy(desc_Psi);
+ cublasMpMatrixDescriptorDestroy(desc_Cij);
+
+ cudaFreeAsync(d_work, cublas_res.stream);
+ free(h_work);
+#endif // __CUBLASMP
}
template
diff --git a/source/source_lcao/module_rt/norm_psi.h b/source/source_lcao/module_rt/norm_psi.h
index 9c9435f318..f61aae86ec 100644
--- a/source/source_lcao/module_rt/norm_psi.h
+++ b/source/source_lcao/module_rt/norm_psi.h
@@ -8,6 +8,7 @@
#include "source_base/module_container/ATen/core/tensor.h" // ct::Tensor
#include "source_basis/module_ao/parallel_orbitals.h"
+#include "source_lcao/module_rt/kernels/cublasmp_context.h"
#include
@@ -39,7 +40,8 @@ void norm_psi_tensor(const Parallel_Orbitals* pv,
const ct::Tensor& Stmp,
ct::Tensor& psi_k,
std::ofstream& ofs_running,
- const int print_matrix);
+ const int print_matrix,
+ CublasMpResources& cublas_res);
template
void norm_psi_tensor_lapack(const Parallel_Orbitals* pv,
diff --git a/source/source_lcao/module_rt/propagator.cpp b/source/source_lcao/module_rt/propagator.cpp
index 2854839066..b74b4fb8dd 100644
--- a/source/source_lcao/module_rt/propagator.cpp
+++ b/source/source_lcao/module_rt/propagator.cpp
@@ -54,7 +54,8 @@ void Propagator::compute_propagator_tensor(const int nlocal,
ct::Tensor& U_operator,
std::ofstream& ofs_running,
const int print_matrix,
- const bool use_lapack) const
+ const bool use_lapack,
+ CublasMpResources& cublas_res) const
{
int tag = 0;
switch (ptype)
@@ -62,7 +63,7 @@ void Propagator::compute_propagator_tensor(const int nlocal,
case 0:
if (!use_lapack)
{
- compute_propagator_cn2_tensor(nlocal, Stmp, Htmp, U_operator, ofs_running, print_matrix);
+ compute_propagator_cn2_tensor(nlocal, Stmp, Htmp, U_operator, ofs_running, print_matrix, cublas_res);
}
else
{
@@ -91,7 +92,8 @@ template void Propagator::compute_propagator_tensor(con
ct::Tensor& U_operator,
std::ofstream& ofs_running,
const int print_matrix,
- const bool use_lapack) const;
+ const bool use_lapack,
+ CublasMpResources& cublas_res) const;
#if ((defined __CUDA) /* || (defined __ROCM) */)
template void Propagator::compute_propagator_tensor(const int nlocal,
const ct::Tensor& Stmp,
@@ -100,7 +102,8 @@ template void Propagator::compute_propagator_tensor(con
ct::Tensor& U_operator,
std::ofstream& ofs_running,
const int print_matrix,
- const bool use_lapack) const;
+ const bool use_lapack,
+ CublasMpResources& cublas_res) const;
#endif // __CUDA
#endif // __MPI
} // namespace module_rt
diff --git a/source/source_lcao/module_rt/propagator.h b/source/source_lcao/module_rt/propagator.h
index ca9c7c140b..7bc12e9715 100644
--- a/source/source_lcao/module_rt/propagator.h
+++ b/source/source_lcao/module_rt/propagator.h
@@ -9,6 +9,7 @@
#include "source_base/constants.h"
#include "source_base/module_container/ATen/core/tensor.h" // ct::Tensor
#include "source_basis/module_ao/parallel_orbitals.h"
+#include "source_lcao/module_rt/kernels/cublasmp_context.h"
#include
@@ -139,7 +140,8 @@ class Propagator
ct::Tensor& U_operator,
std::ofstream& ofs_running,
const int print_matrix,
- const bool use_lapack) const;
+ const bool use_lapack,
+ CublasMpResources& cublas_res) const;
#endif // __MPI
private:
@@ -170,7 +172,8 @@ class Propagator
const ct::Tensor& Htmp,
ct::Tensor& U_operator,
std::ofstream& ofs_running,
- const int print_matrix) const;
+ const int print_matrix,
+ CublasMpResources& cublas_res) const;
template
void compute_propagator_cn2_tensor_lapack(const int nlocal,
diff --git a/source/source_lcao/module_rt/propagator_cn2.cpp b/source/source_lcao/module_rt/propagator_cn2.cpp
index 8ef07b0ebb..3f85ed26a0 100644
--- a/source/source_lcao/module_rt/propagator_cn2.cpp
+++ b/source/source_lcao/module_rt/propagator_cn2.cpp
@@ -6,6 +6,7 @@
#include "source_base/module_device/memory_op.h" // memory operations
#include "source_base/module_external/blas_connector.h"
#include "source_base/module_external/scalapack_connector.h"
+#include "source_base/timer.h"
#include "source_io/module_parameter/parameter.h"
#include
@@ -252,249 +253,304 @@ void Propagator::compute_propagator_cn2_tensor(const int nlocal,
const ct::Tensor& Htmp,
ct::Tensor& U_operator,
std::ofstream& ofs_running,
- const int print_matrix) const
+ const int print_matrix,
+ CublasMpResources& cublas_res) const
{
- // (1) copy Htmp to Numerator & Denominator
- ct::Tensor Numerator(ct::DataType::DT_COMPLEX_DOUBLE,
- ct::DeviceType::CpuDevice,
- ct::TensorShape({this->ParaV->nloc}));
- Numerator.zero();
- BlasConnector::copy(this->ParaV->nloc,
- Htmp.data>(),
- 1,
- Numerator.data>(),
- 1);
-
- ct::Tensor Denominator(ct::DataType::DT_COMPLEX_DOUBLE,
- ct::DeviceType::CpuDevice,
- ct::TensorShape({this->ParaV->nloc}));
- Denominator.zero();
- BlasConnector::copy(this->ParaV->nloc,
- Htmp.data>(),
- 1,
- Denominator.data>(),
- 1);
-
- if (print_matrix)
+#ifdef __CUBLASMP
+ // 1. Resource Validation
+ if (!cublas_res.is_initialized || cublas_res.cublasmp_grid == nullptr || cublas_res.cusolvermp_grid == nullptr)
{
- ofs_running << std::endl;
- ofs_running << " S matrix :" << std::endl;
- for (int i = 0; i < this->ParaV->nrow; i++)
- {
- const int in = i * this->ParaV->ncol;
- for (int j = 0; j < this->ParaV->ncol; j++)
- {
- ofs_running << Stmp.data>()[in + j].real() << "+"
- << Stmp.data>()[in + j].imag() << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
- ofs_running << std::endl;
- ofs_running << " H matrix :" << std::endl;
- for (int i = 0; i < this->ParaV->nrow; i++)
- {
- const int in = i * this->ParaV->ncol;
- for (int j = 0; j < this->ParaV->ncol; j++)
- {
- ofs_running << Numerator.data>()[in + j].real() << "+"
- << Numerator.data>()[in + j].imag() << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
+ return;
}
- // ->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
- // (2) compute Numerator & Denominator by GEADD
- // Numerator = Stmp - i*para * Htmp; beta1 = - para = -0.25 * this->dt
- // Denominator = Stmp + i*para * Htmp; beta2 = para = 0.25 * this->dt
- std::complex alpha = {1.0, 0.0};
- std::complex beta1 = {0.0, -0.25 * this->dt};
- std::complex beta2 = {0.0, 0.25 * this->dt};
-
- ScalapackConnector::geadd('N',
- nlocal,
- nlocal,
- alpha,
- Stmp.data>(),
- 1,
- 1,
- this->ParaV->desc,
- beta1,
- Numerator.data>(),
- 1,
- 1,
- this->ParaV->desc);
- ScalapackConnector::geadd('N',
- nlocal,
- nlocal,
- alpha,
- Stmp.data>(),
- 1,
- 1,
- this->ParaV->desc,
- beta2,
- Denominator.data>(),
- 1,
- 1,
- this->ParaV->desc);
-
- if (print_matrix)
- {
- ofs_running << " beta=" << beta1 << std::endl;
- ofs_running << " fenmu:" << std::endl;
- for (int i = 0; i < this->ParaV->nrow; i++)
- {
- const int in = i * this->ParaV->ncol;
- for (int j = 0; j < this->ParaV->ncol; j++)
- {
- ofs_running << Denominator.data>()[in + j].real() << "+"
- << Denominator.data>()[in + j].imag() << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
- }
-
- //->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
- // (3) Next, invert Denominator
- ct::Tensor ipiv(ct::DataType::DT_INT,
- ct::DeviceType::CpuDevice,
- ct::TensorShape({this->ParaV->nrow + this->ParaV->nb}));
- ipiv.zero();
- int info = 0;
- // (3.1) compute ipiv
- ScalapackConnector::getrf(nlocal,
- nlocal,
- Denominator.data>(),
- 1,
- 1,
- this->ParaV->desc,
- ipiv.data(),
- &info);
-
- // Print ipiv
- if (print_matrix)
- {
- ofs_running << " this->ParaV->nloc = " << this->ParaV->nloc << std::endl;
- ofs_running << " this->ParaV->nrow = " << this->ParaV->nrow << std::endl;
- ofs_running << " this->ParaV->ncol = " << this->ParaV->ncol << std::endl;
- ofs_running << " this->ParaV->nb = " << this->ParaV->nb << std::endl;
- ofs_running << " this->ParaV->get_block_size() = " << this->ParaV->get_block_size() << std::endl;
- ofs_running << " nlocal = " << nlocal << std::endl;
- ofs_running << " ipiv:" << std::endl;
- for (int i = 0; i < this->ParaV->nloc; i++)
- {
- ofs_running << ipiv.data()[i] << " ";
- }
- ofs_running << std::endl;
- }
-
- int lwork = -1;
- int liwotk = -1;
- ct::Tensor work(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::CpuDevice, ct::TensorShape({1}));
- ct::Tensor iwork(ct::DataType::DT_INT, ct::DeviceType::CpuDevice, ct::TensorShape({1}));
- // (3.2) compute work
- ScalapackConnector::getri(nlocal,
- Denominator.data>(),
- 1,
- 1,
- this->ParaV->desc,
- ipiv.data(),
- work.data>(),
- &lwork,
- iwork.data(),
- &liwotk,
- &info);
- lwork = work.data>()[0].real();
- work.resize(ct::TensorShape({lwork}));
- liwotk = iwork.data()[0];
- iwork.resize(ct::TensorShape({liwotk}));
- // (3.3) compute inverse matrix of Denominator
- ScalapackConnector::getri(nlocal,
- Denominator.data>(),
- 1,
- 1,
- this->ParaV->desc,
- ipiv.data(),
- work.data>(),
- &lwork,
- iwork.data(),
- &liwotk,
- &info);
- assert(0 == info);
-
- //->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
-
- // (4) U_operator = Denominator * Numerator;
- ScalapackConnector::gemm('N',
- 'N',
- nlocal,
- nlocal,
- nlocal,
- 1.0,
- Denominator.data>(),
- 1,
+ assert(Stmp.device_type() == ct::DeviceType::GpuDevice);
+ assert(Htmp.device_type() == ct::DeviceType::GpuDevice);
+ assert(U_operator.device_type() == ct::DeviceType::GpuDevice);
+
+ // 2. Extract Pointers
+ void* d_S = static_cast(Stmp.data>());
+ void* d_H = static_cast(Htmp.data>());
+ void* d_Num = static_cast(U_operator.data>());
+
+ int64_t len_loc = this->ParaV->nloc;
+
+ // Allocate temporary tensor for denominator matrix
+ ct::Tensor Denominator_gpu(ct::DataType::DT_COMPLEX_DOUBLE, ct::DeviceType::GpuDevice, ct::TensorShape({len_loc}));
+ void* d_Den = static_cast(Denominator_gpu.data>());
+
+ // 3. Matrix Descriptors Creation
+ int64_t m_global = this->ParaV->desc[2];
+ int64_t n_global = this->ParaV->desc[3];
+ int64_t mb = this->ParaV->desc[4];
+ int64_t nb = this->ParaV->desc[5];
+ int64_t rsrc = this->ParaV->desc[6];
+ int64_t csrc = this->ParaV->desc[7];
+ int64_t lld = this->ParaV->desc[8];
+
+ // 3.1 cuBLASMp Descriptor
+ cublasMpMatrixDescriptor_t desc_blas;
+ cublasMpMatrixDescriptorCreate(m_global,
+ n_global,
+ mb,
+ nb,
+ rsrc,
+ csrc,
+ lld,
+ CUDA_C_64F,
+ cublas_res.cublasmp_grid,
+ &desc_blas);
+
+ // 3.2 cuSOLVERMp Descriptor
+ cusolverMpMatrixDescriptor_t desc_solver;
+ cusolverMpCreateMatrixDesc(&desc_solver,
+ cublas_res.cusolvermp_grid,
+ CUDA_C_64F,
+ m_global,
+ n_global,
+ mb,
+ nb,
+ rsrc,
+ csrc,
+ lld);
+
+ // 4. Construct A (Denominator) and B (Numerator) using Geadd
+ std::complex one = {1.0, 0.0};
+ std::complex coef_neg_i = {0.0, -0.25 * this->dt};
+ std::complex coef_pos_i = {0.0, 0.25 * this->dt};
+
+ cudaMemcpyAsync(d_Num, d_S, len_loc * sizeof(std::complex), cudaMemcpyDeviceToDevice, cublas_res.stream);
+ cudaMemcpyAsync(d_Den, d_S, len_loc * sizeof(std::complex), cudaMemcpyDeviceToDevice, cublas_res.stream);
+
+ size_t ws_geadd_dev = 0, ws_geadd_host = 0;
+ cublasMpGeadd_bufferSize(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ m_global,
+ n_global,
+ &coef_neg_i,
+ d_H,
1,
- this->ParaV->desc,
- Numerator.data>(),
1,
- 1,
- this->ParaV->desc,
- 0.0,
- U_operator.data>(),
+ desc_blas,
+ &one,
+ d_Num,
1,
1,
- this->ParaV->desc);
-
- if (print_matrix)
+ desc_blas,
+ &ws_geadd_dev,
+ &ws_geadd_host);
+
+ void *d_work_geadd = nullptr, *h_work_geadd = nullptr;
+ cudaMallocAsync(&d_work_geadd, ws_geadd_dev, cublas_res.stream);
+ h_work_geadd = malloc(ws_geadd_host);
+
+ // B = S - i * (dt/4) * H
+ cublasMpGeadd(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ m_global,
+ n_global,
+ &coef_neg_i,
+ d_H,
+ 1,
+ 1,
+ desc_blas,
+ &one,
+ d_Num,
+ 1,
+ 1,
+ desc_blas,
+ d_work_geadd,
+ ws_geadd_dev,
+ h_work_geadd,
+ ws_geadd_host);
+
+ // A = S + i * (dt/4) * H
+ cublasMpGeadd(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ m_global,
+ n_global,
+ &coef_pos_i,
+ d_H,
+ 1,
+ 1,
+ desc_blas,
+ &one,
+ d_Den,
+ 1,
+ 1,
+ desc_blas,
+ d_work_geadd,
+ ws_geadd_dev,
+ h_work_geadd,
+ ws_geadd_host);
+
+ cudaFreeAsync(d_work_geadd, cublas_res.stream);
+ free(h_work_geadd);
+
+ // 5. QR Factorization of A (Denominator)
+ int64_t tau_size = m_global + nb;
+ void* d_tau;
+ cudaMallocAsync(&d_tau, tau_size * sizeof(std::complex<double>), cublas_res.stream);
+
+ int* d_info;
+ cudaMallocAsync(&d_info, sizeof(int), cublas_res.stream);
+ cudaMemsetAsync(d_info, 0, sizeof(int), cublas_res.stream);
+
+ size_t ws_geqrf_dev = 0, ws_geqrf_host = 0;
+ cusolverMpGeqrf_bufferSize(cublas_res.cusolvermp_handle,
+ m_global,
+ n_global,
+ d_Den,
+ 1,
+ 1,
+ desc_solver,
+ CUDA_C_64F,
+ &ws_geqrf_dev,
+ &ws_geqrf_host);
+
+ void *d_work_geqrf = nullptr, *h_work_geqrf = nullptr;
+ cudaMallocAsync(&d_work_geqrf, ws_geqrf_dev, cublas_res.stream);
+ h_work_geqrf = malloc(ws_geqrf_host);
+
+ cusolverMpGeqrf(cublas_res.cusolvermp_handle,
+ m_global,
+ n_global,
+ d_Den,
+ 1,
+ 1,
+ desc_solver,
+ d_tau,
+ CUDA_C_64F,
+ d_work_geqrf,
+ ws_geqrf_dev,
+ h_work_geqrf,
+ ws_geqrf_host,
+ d_info);
+
+ cudaFreeAsync(d_work_geqrf, cublas_res.stream);
+ free(h_work_geqrf);
+
+ // Check QR Info
+ int h_info = 0;
+ cudaMemcpyAsync(&h_info, d_info, sizeof(int), cudaMemcpyDeviceToHost, cublas_res.stream);
+ cudaStreamSynchronize(cublas_res.stream);
+ if (h_info != 0)
{
- ofs_running << " fenmu^-1:" << std::endl;
- for (int i = 0; i < this->ParaV->nrow; i++)
- {
- const int in = i * this->ParaV->ncol;
- for (int j = 0; j < this->ParaV->ncol; j++)
- {
- ofs_running << Denominator.data<std::complex<double>>()[in + j].real() << "+"
- << Denominator.data<std::complex<double>>()[in + j].imag() << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
- ofs_running << " fenzi:" << std::endl;
- for (int i = 0; i < this->ParaV->nrow; i++)
- {
- const int in = i * this->ParaV->ncol;
- for (int j = 0; j < this->ParaV->ncol; j++)
- {
- ofs_running << Numerator.data<std::complex<double>>()[in + j].real() << "+"
- << Numerator.data<std::complex<double>>()[in + j].imag() << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
- ofs_running << " U operator:" << std::endl;
- for (int i = 0; i < this->ParaV->nrow; i++)
- {
- const int in = i * this->ParaV->ncol;
- for (int j = 0; j < this->ParaV->ncol; j++)
- {
- double aa = U_operator.data<std::complex<double>>()[in + j].real();
- double bb = U_operator.data<std::complex<double>>()[in + j].imag();
- if (std::abs(aa) < 1e-8)
- {
- aa = 0.0;
- }
- if (std::abs(bb) < 1e-8)
- {
- bb = 0.0;
- }
- ofs_running << aa << "+" << bb << "i ";
- }
- ofs_running << std::endl;
- }
+ std::cerr << "CRITICAL: cusolverMpGeqrf failed with Info: " << h_info << std::endl;
+ MPI_Abort(MPI_COMM_WORLD, -1);
}
+
+ // 6. Apply Q^H to B (Numerator)
+ size_t ws_ormqr_dev = 0, ws_ormqr_host = 0;
+ cusolverMpOrmqr_bufferSize(cublas_res.cusolvermp_handle,
+ CUBLAS_SIDE_LEFT,
+ CUBLAS_OP_C,
+ m_global,
+ n_global,
+ n_global,
+ d_Den,
+ 1,
+ 1,
+ desc_solver,
+ d_tau,
+ d_Num,
+ 1,
+ 1,
+ desc_solver,
+ CUDA_C_64F,
+ &ws_ormqr_dev,
+ &ws_ormqr_host);
+
+ void *d_work_ormqr = nullptr, *h_work_ormqr = nullptr;
+ cudaMallocAsync(&d_work_ormqr, ws_ormqr_dev, cublas_res.stream);
+ h_work_ormqr = malloc(ws_ormqr_host);
+
+ cusolverMpOrmqr(cublas_res.cusolvermp_handle,
+ CUBLAS_SIDE_LEFT,
+ CUBLAS_OP_C,
+ m_global,
+ n_global,
+ n_global,
+ d_Den,
+ 1,
+ 1,
+ desc_solver,
+ d_tau,
+ d_Num,
+ 1,
+ 1,
+ desc_solver,
+ CUDA_C_64F,
+ d_work_ormqr,
+ ws_ormqr_dev,
+ h_work_ormqr,
+ ws_ormqr_host,
+ d_info);
+
+ cudaFreeAsync(d_work_ormqr, cublas_res.stream);
+ free(h_work_ormqr);
+
+ // 7. Solve Triangular System (TRSM)
+ size_t ws_trsm_dev = 0, ws_trsm_host = 0;
+ std::complex<double> alpha_trsm = {1.0, 0.0};
+
+ cublasMpTrsm_bufferSize(cublas_res.cublasmp_handle,
+ CUBLAS_SIDE_LEFT,
+ CUBLAS_FILL_MODE_UPPER,
+ CUBLAS_OP_N,
+ CUBLAS_DIAG_NON_UNIT,
+ m_global,
+ n_global,
+ &alpha_trsm,
+ d_Den,
+ 1,
+ 1,
+ desc_blas,
+ d_Num,
+ 1,
+ 1,
+ desc_blas,
+ CUBLAS_COMPUTE_64F,
+ &ws_trsm_dev,
+ &ws_trsm_host);
+
+ void *d_work_trsm = nullptr, *h_work_trsm = nullptr;
+ cudaMallocAsync(&d_work_trsm, ws_trsm_dev, cublas_res.stream);
+ h_work_trsm = malloc(ws_trsm_host);
+
+ cublasMpTrsm(cublas_res.cublasmp_handle,
+ CUBLAS_SIDE_LEFT,
+ CUBLAS_FILL_MODE_UPPER,
+ CUBLAS_OP_N,
+ CUBLAS_DIAG_NON_UNIT,
+ m_global,
+ n_global,
+ &alpha_trsm,
+ d_Den,
+ 1,
+ 1,
+ desc_blas,
+ d_Num,
+ 1,
+ 1,
+ desc_blas,
+ CUBLAS_COMPUTE_64F,
+ d_work_trsm,
+ ws_trsm_dev,
+ h_work_trsm,
+ ws_trsm_host);
+
+ cudaFreeAsync(d_work_trsm, cublas_res.stream);
+ free(h_work_trsm);
+
+ // 8. Cleanup and Final Synchronization
+ cudaStreamSynchronize(cublas_res.stream);
+
+ cublasMpMatrixDescriptorDestroy(desc_blas);
+ cusolverMpDestroyMatrixDesc(desc_solver);
+
+ cudaFreeAsync(d_tau, cublas_res.stream);
+ cudaFreeAsync(d_info, cublas_res.stream);
+#endif // __CUBLASMP
}
template
diff --git a/source/source_lcao/module_rt/upsi.cpp b/source/source_lcao/module_rt/upsi.cpp
index 0982a77426..e0f1b38d75 100644
--- a/source/source_lcao/module_rt/upsi.cpp
+++ b/source/source_lcao/module_rt/upsi.cpp
@@ -2,7 +2,9 @@
#include "source_base/module_container/ATen/kernels/blas.h"
#include "source_base/module_external/scalapack_connector.h"
+#include "source_base/timer.h"
+#include <cassert>
#include
#include
@@ -93,74 +95,126 @@ void upsi_tensor(const Parallel_Orbitals* pv,
const ct::Tensor& psi_k_laststep,
ct::Tensor& psi_k,
std::ofstream& ofs_running,
- const int print_matrix)
+ const int print_matrix,
+ CublasMpResources& cublas_res)
{
- ScalapackConnector::gemm('N',
- 'N',
- nlocal,
- nband,
- nlocal,
- 1.0,
- U_operator.data<std::complex<double>>(),
- 1,
- 1,
- pv->desc,
- psi_k_laststep.data<std::complex<double>>(),
- 1,
- 1,
- pv->desc_wfc,
- 0.0,
- psi_k.data<std::complex<double>>(),
- 1,
- 1,
- pv->desc_wfc);
-
- if (print_matrix)
+#ifdef __CUBLASMP
+ // 1. Resource validation
+ if (!cublas_res.is_initialized || cublas_res.cublasmp_grid == nullptr)
{
- ofs_running << std::endl;
- ofs_running << " psi_k:" << std::endl;
- for (int i = 0; i < pv->ncol_bands; i++)
- {
- const int in = i * pv->ncol;
- for (int j = 0; j < pv->ncol; j++)
- {
- double aa = psi_k.data<std::complex<double>>()[in + j].real();
- double bb = psi_k.data<std::complex<double>>()[in + j].imag();
- if (std::abs(aa) < 1e-8)
- {
- aa = 0.0;
- }
- if (std::abs(bb) < 1e-8)
- {
- bb = 0.0;
- }
- ofs_running << aa << "+" << bb << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
- ofs_running << " psi_k_laststep:" << std::endl;
- for (int i = 0; i < pv->ncol_bands; i++)
- {
- const int in = i * pv->ncol;
- for (int j = 0; j < pv->ncol; j++)
- {
- double aa = psi_k_laststep.data<std::complex<double>>()[in + j].real();
- double bb = psi_k_laststep.data<std::complex<double>>()[in + j].imag();
- if (std::abs(aa) < 1e-8)
- {
- aa = 0.0;
- }
- if (std::abs(bb) < 1e-8)
- {
- bb = 0.0;
- }
- ofs_running << aa << "+" << bb << "i ";
- }
- ofs_running << std::endl;
- }
- ofs_running << std::endl;
+ return;
}
+
+ assert(U_operator.device_type() == ct::DeviceType::GpuDevice);
+ assert(psi_k_laststep.device_type() == ct::DeviceType::GpuDevice);
+ assert(psi_k.device_type() == ct::DeviceType::GpuDevice);
+
+ // 2. Extract device pointers
+ void* d_U = static_cast<void*>(const_cast<std::complex<double>*>(U_operator.data<std::complex<double>>()));
+ void* d_Psi_old
+ = static_cast<void*>(const_cast<std::complex<double>*>(psi_k_laststep.data<std::complex<double>>()));
+ void* d_Psi_k = static_cast<void*>(psi_k.data<std::complex<double>>());
+
+ // 3. Create matrix descriptor for U operator (N x N)
+ int64_t m_u = pv->desc[2];
+ int64_t n_u = pv->desc[3];
+ int64_t mb_u = pv->desc[4];
+ int64_t nb_u = pv->desc[5];
+ int64_t lld_u = pv->desc[8];
+
+ cublasMpMatrixDescriptor_t desc_U;
+ cublasMpMatrixDescriptorCreate(m_u, n_u, mb_u, nb_u, 0, 0, lld_u, CUDA_C_64F, cublas_res.cublasmp_grid, &desc_U);
+
+ // 4. Create matrix descriptor for Psi (N x nband)
+ int64_t m_psi = pv->desc_wfc[2];
+ int64_t n_psi = pv->desc_wfc[3];
+ int64_t mb_psi = pv->desc_wfc[4];
+ int64_t nb_psi = pv->desc_wfc[5];
+ int64_t lld_psi = pv->desc_wfc[8];
+
+ cublasMpMatrixDescriptor_t desc_Psi;
+ cublasMpMatrixDescriptorCreate(m_psi,
+ n_psi,
+ mb_psi,
+ nb_psi,
+ 0,
+ 0,
+ lld_psi,
+ CUDA_C_64F,
+ cublas_res.cublasmp_grid,
+ &desc_Psi);
+
+ // 5. Query workspace size for GEMM: Psi_k = 1.0 * U * Psi_old + 0.0 * Psi_k
+ std::complex<double> alpha = {1.0, 0.0};
+ std::complex<double> beta = {0.0, 0.0};
+ size_t ws_gemm_dev = 0;
+ size_t ws_gemm_host = 0;
+
+ cublasMpGemm_bufferSize(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ CUBLAS_OP_N,
+ m_u,
+ n_psi,
+ n_u,
+ &alpha,
+ d_U,
+ 1,
+ 1,
+ desc_U,
+ d_Psi_old,
+ 1,
+ 1,
+ desc_Psi,
+ &beta,
+ d_Psi_k,
+ 1,
+ 1,
+ desc_Psi,
+ CUBLAS_COMPUTE_64F,
+ &ws_gemm_dev,
+ &ws_gemm_host);
+
+ void* d_work = nullptr;
+ void* h_work = nullptr;
+
+ cudaMallocAsync(&d_work, ws_gemm_dev, cublas_res.stream);
+ h_work = malloc(ws_gemm_host);
+
+ // 6. Execute GEMM
+ cublasMpGemm(cublas_res.cublasmp_handle,
+ CUBLAS_OP_N,
+ CUBLAS_OP_N,
+ m_u,
+ n_psi,
+ n_u,
+ &alpha,
+ d_U,
+ 1,
+ 1,
+ desc_U,
+ d_Psi_old,
+ 1,
+ 1,
+ desc_Psi,
+ &beta,
+ d_Psi_k,
+ 1,
+ 1,
+ desc_Psi,
+ CUBLAS_COMPUTE_64F,
+ d_work,
+ ws_gemm_dev,
+ h_work,
+ ws_gemm_host);
+
+ // 7. Synchronize and clean up resources
+ cudaStreamSynchronize(cublas_res.stream);
+
+ cublasMpMatrixDescriptorDestroy(desc_U);
+ cublasMpMatrixDescriptorDestroy(desc_Psi);
+ cudaFreeAsync(d_work, cublas_res.stream);
+ free(h_work);
+#endif // __CUBLASMP
}
template
diff --git a/source/source_lcao/module_rt/upsi.h b/source/source_lcao/module_rt/upsi.h
index 6cf0976840..e5404a0622 100644
--- a/source/source_lcao/module_rt/upsi.h
+++ b/source/source_lcao/module_rt/upsi.h
@@ -9,6 +9,7 @@
#include "source_base/module_container/ATen/core/tensor.h" // ct::Tensor
#include "source_basis/module_ao/parallel_orbitals.h"
+#include "source_lcao/module_rt/kernels/cublasmp_context.h"
#include
@@ -42,7 +43,8 @@ void upsi_tensor(const Parallel_Orbitals* pv,
const ct::Tensor& psi_k_laststep,
ct::Tensor& psi_k,
std::ofstream& ofs_running,
- const int print_matrix);
+ const int print_matrix,
+ CublasMpResources& cublas_res);
template
void upsi_tensor_lapack(const Parallel_Orbitals* pv,