diff --git a/source/source_base/kernels/cuda/math_kernel_op.cu b/source/source_base/kernels/cuda/math_kernel_op.cu
index 000b98b8f5..c5b0648c49 100644
--- a/source/source_base/kernels/cuda/math_kernel_op.cu
+++ b/source/source_base/kernels/cuda/math_kernel_op.cu
@@ -2,6 +2,7 @@
#include "source_base/kernels/math_kernel_op.h"
#include "source_psi/psi.h"
#include "source_base/tool_quit.h"
+#include "source_base/module_container/base/third_party/cublas.h"
#include
#include
@@ -175,9 +176,28 @@ void gemv_op::operator()(const char& trans,
const int& incy)
{
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
- CHECK_CUBLAS(cublasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx));
+ CHECK_CUBLAS(cublasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
}
+template <>
+void gemv_op::operator()(const char& trans,
+ const int& m,
+ const int& n,
+ const float* alpha,
+ const float* A,
+ const int& lda,
+ const float* X,
+ const int& incx,
+ const float* beta,
+ float* Y,
+ const int& incy)
+{
+ cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
+ CHECK_CUBLAS(cublasSgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
+}
+
+
+
template <>
void gemv_op, base_device::DEVICE_GPU>::operator()(const char& trans,
const int& m,
@@ -194,7 +214,7 @@ void gemv_op, base_device::DEVICE_GPU>::operator()(const cha
cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
cuFloatComplex alpha = make_cuFloatComplex(alpha_in->real(), alpha_in->imag());
cuFloatComplex beta = make_cuFloatComplex(beta_in->real(), beta_in->imag());
- CHECK_CUBLAS(cublasCgemv(cublas_handle, cutrans, m, n, &alpha, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta, (cuFloatComplex*)Y, incx));
+ CHECK_CUBLAS(cublasCgemv(cublas_handle, cutrans, m, n, &alpha, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta, (cuFloatComplex*)Y, incy));
}
template <>
@@ -215,7 +235,7 @@ void gemv_op, base_device::DEVICE_GPU>::operator()(const ch
cuDoubleComplex beta = make_cuDoubleComplex(beta_in->real(), beta_in->imag());
// icpc and nvcc have some compatible problems
// We must use cuDoubleComplex instead of converting std::complex* to cuDoubleComplex*
- CHECK_CUBLAS(cublasZgemv(cublas_handle, cutrans, m, n, &alpha, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta, (cuDoubleComplex*)Y, incx));
+ CHECK_CUBLAS(cublasZgemv(cublas_handle, cutrans, m, n, &alpha, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta, (cuDoubleComplex*)Y, incy));
}
template <>
diff --git a/source/source_base/kernels/rocm/math_kernel_op.hip.cu b/source/source_base/kernels/rocm/math_kernel_op.hip.cu
index f60cdfc3b1..1b4f30d6b2 100644
--- a/source/source_base/kernels/rocm/math_kernel_op.hip.cu
+++ b/source/source_base/kernels/rocm/math_kernel_op.hip.cu
@@ -188,7 +188,7 @@ void gemv_op::operator()(const char& trans,
const int& incy)
{
hipblasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
- hipblasErrcheck(hipblasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx));
+ hipblasErrcheck(hipblasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
}
template <>
@@ -205,7 +205,7 @@ void gemv_op, base_device::DEVICE_GPU>::operator()(const cha
const int& incy)
{
hipblasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
- hipblasErrcheck(hipblasCgemv(cublas_handle, cutrans, m, n, (hipblasComplex*)alpha, (hipblasComplex*)A, lda, (hipblasComplex*)X, incx, (hipblasComplex*)beta, (hipblasComplex*)Y, incx));
+ hipblasErrcheck(hipblasCgemv(cublas_handle, cutrans, m, n, (hipblasComplex*)alpha, (hipblasComplex*)A, lda, (hipblasComplex*)X, incx, (hipblasComplex*)beta, (hipblasComplex*)Y, incy));
}
template <>
@@ -222,7 +222,7 @@ void gemv_op, base_device::DEVICE_GPU>::operator()(const ch
const int& incy)
{
hipblasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
- hipblasErrcheck(hipblasZgemv(cublas_handle, cutrans, m, n, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)A, lda, (hipblasDoubleComplex*)X, incx, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)Y, incx));
+ hipblasErrcheck(hipblasZgemv(cublas_handle, cutrans, m, n, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)A, lda, (hipblasDoubleComplex*)X, incx, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)Y, incy));
}
template <>
diff --git a/source/source_base/module_container/base/third_party/cublas.h b/source/source_base/module_container/base/third_party/cublas.h
index f6653011a4..fabc32e983 100644
--- a/source/source_base/module_container/base/third_party/cublas.h
+++ b/source/source_base/module_container/base/third_party/cublas.h
@@ -152,7 +152,7 @@ void gemv_batched(cublasHandle_t& handle, const char& trans, const int& m, const
{
for (int ii = 0; ii < batch_size; ++ii) {
// Call the single GEMV for each pair of matrix A[ii] and vector x[ii]
- cuBlasConnector::gemv(handle, trans, m, n, alpha, A[ii], lda, x[ii], incy, beta, y[ii], incy);
+ cuBlasConnector::gemv(handle, trans, m, n, alpha, A[ii], lda, x[ii], incx, beta, y[ii], incy);
}
}
diff --git a/source/source_hsolver/diago_dav_subspace.cpp b/source/source_hsolver/diago_dav_subspace.cpp
index 8fa997774b..27c6a5b348 100644
--- a/source/source_hsolver/diago_dav_subspace.cpp
+++ b/source/source_hsolver/diago_dav_subspace.cpp
@@ -77,7 +77,7 @@ Diago_DavSubspace::Diago_DavSubspace(const std::vector& precond
if (this->device == base_device::GpuDevice)
{
resmem_real_op()(this->d_precondition, nbasis_in);
- // syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition.data(), nbasis_in);
+ syncmem_var_h2d_op()(this->d_precondition, this->precondition.data(), nbasis_in);
resmem_complex_op()(this->d_scc, this->nbase_x * this->nbase_x);
resmem_real_op()(this->d_eigenvalue, this->nbase_x);
}
@@ -295,6 +295,8 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func,
}
}
+ if (notconv > 1){
+
#ifdef __DSP
ModuleBase::gemm_op_mt()
#else
@@ -313,6 +315,28 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func,
this->zero,
psi_iter + (nbase) * this->dim,
this->dim);
+ } else
+ {
+
+#ifdef __DSP
+ ModuleBase::gemv_op_mt()
+#else
+ ModuleBase::gemv_op()
+#endif
+ ('N',
+ this->dim, // m: row of A
+ nbase, // n: col of A
+ this->one, // alpha
+ hpsi, // A dim * nbase
+ this->dim, // LDA: if(N) max(1,m)
+ vcc, // X nbase
+ 1, // incx
+ this->zero, // beta
+ psi_iter + (nbase) * this->dim, // Y dim
+ 1 // incy
+ );
+ }
+
// Eigenvalues operation section
Real* e_temp_hd = eigenvalue_iter->data();
@@ -325,6 +349,8 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func,
// vcc = - vcc * eigenvalue
ModuleBase::matrix_mul_vector_op()(nbase, notconv, vcc, this->nbase_x, e_temp_hd, -1.0, vcc, this->nbase_x);
+ if (notconv > 1){
+
#ifdef __DSP
ModuleBase::gemm_op_mt()
#else
@@ -343,6 +369,26 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func,
this->one,
psi_iter + nbase * this->dim,
this->dim);
+ } else
+ {
+#ifdef __DSP
+ ModuleBase::gemv_op_mt()
+#else
+ ModuleBase::gemv_op()
+#endif
+ ('N',
+ this->dim, // m: row of A
+ nbase, // n: col of A
+ this->one, // alpha
+ spsi, // A dim * nbase
+ this->dim, // LDA: if(N) max(1,m)
+ vcc, // X nbase
+ 1, // incx
+ this->one, // beta
+ psi_iter + nbase * this->dim, // Y dim
+ 1 // incy
+ );
+ }
// Precondition section
#if defined(__CUDA) || defined(__ROCM)
@@ -413,6 +459,7 @@ void Diago_DavSubspace::cal_elem(const int& dim,
{
ModuleBase::timer::tick("Diago_DavSubspace", "cal_elem");
+ if (notconv > 1){
#ifdef __DSP
ModuleBase::gemm_op_mt()
#else
@@ -451,6 +498,46 @@ void Diago_DavSubspace::cal_elem(const int& dim,
&scc[nbase * this->nbase_x],
this->nbase_x);
+ } else {
+
+#ifdef __DSP
+ ModuleBase::gemv_op_mt()
+#else
+ ModuleBase::gemv_op()
+#endif
+ ('C',
+ this->dim, // m: row of A
+ nbase + notconv, // n: col of A
+ this->one, // alpha
+ psi_iter, // A dim * nbase
+ this->dim, // LDA: if(N) max(1,m)
+ &hpsi[nbase * this->dim], // X nbase
+ 1, // incx
+ this->zero, // beta
+ &hcc[nbase * this->nbase_x], // Y dim
+ 1 // incy
+ );
+#ifdef __DSP
+ ModuleBase::gemv_op_mt()
+#else
+ ModuleBase::gemv_op()
+#endif
+ ('C',
+ this->dim, // m: row of A
+ nbase + notconv, // n: col of A
+ this->one, // alpha
+ psi_iter, // A dim * nbase
+ this->dim, // LDA: if(N) max(1,m)
+ spsi + nbase * this->dim, // X nbase
+ 1, // incx
+ this->zero, // beta
+ &scc[nbase * this->nbase_x], // Y dim
+ 1 // incy
+ );
+
+ }
+
+
#ifdef __MPI
if (this->diag_comm.nproc > 1)
{
diff --git a/source/source_hsolver/diago_david.cpp b/source/source_hsolver/diago_david.cpp
index bf390104af..ef4ba67cf3 100644
--- a/source/source_hsolver/diago_david.cpp
+++ b/source/source_hsolver/diago_david.cpp
@@ -351,7 +351,24 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func,
// basis[nbase] = hpsi * vc_ev_vector = hpsi*vcc
// basis' = vc_ev_vector' * hpsi'
// (dim, notconv) (dim, nbase) (nbase, notconv)
- ModuleBase::gemm_op()('N',
+ if (notconv == 1){
+ //Reuse gemv for vector case to avoid potential bug using gemm call with n=1
+ ModuleBase::gemv_op()('N',
+ dim, // m: row of A
+ nbase, // n: col of A
+ this->one, // alpha
+ hpsi, // A dim * nbase
+ dim, // LDA: if(N) max(1,m)
+ vc_ev_vector, // X nbase
+ 1, // incx
+ this->zero, // beta
+ basis + dim * nbase, // Y dim
+ 1 // incy
+ );
+
+ }else
+ {
+ ModuleBase::gemm_op()('N',
'N',
dim, // m: row of A,C
notconv, // n: col of B,C
@@ -364,7 +381,8 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func,
this->zero, // belta
basis + dim * nbase, // C dim * notconv
dim // LDC: if(N) max(1, m)
- );
+ );
+ }
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// for (int m = 0; m < notconv; m++)
@@ -411,20 +429,37 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func,
// = (H - lambda * S) * psi * vcc
// = (H - lambda * S) * psi_new
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
- ModuleBase::gemm_op()('N',
- 'N',
- dim, // m: row of A,C
- notconv, // n: col of B,C
- nbase, // k: col of A, row of B
- this->one, // alpha
- spsi, // A
- dim, // LDA: if(N) max(1,m) if(T) max(1,k)
- vc_ev_vector, // B
- nbase, // LDB: if(N) max(1,k) if(T) max(1,n)
- this->one, // belta
- basis + dim * nbase, // C dim * notconv
- dim // LDC: if(N) max(1, m)
- );
+ if (notconv == 1){
+ //Use gemv for vector case to avoid potential bug using gemm call with n=1
+ ModuleBase::gemv_op()('N',
+ dim, // m: row of A
+ nbase, // n: col of A
+ this->one, // alpha
+ spsi, // A dim * nbase
+ dim, // LDA: if(N) max(1,m)
+ vc_ev_vector, // X nbase
+ 1, // incx
+ this->one, // beta
+ basis + dim * nbase, // Y dim
+ 1 //incy
+ );
+ } else
+ {
+ ModuleBase::gemm_op()('N',
+ 'N',
+ dim, // m: row of A,C
+ notconv, // n: col of B,C
+ nbase, // k: col of A, row of B
+ this->one, // alpha
+ spsi, // A
+ dim, // LDA: if(N) max(1,m) if(T) max(1,k)
+ vc_ev_vector, // B
+ nbase, // LDB: if(N) max(1,k) if(T) max(1,n)
+ this->one, // beta
+ basis + dim * nbase, // C dim * notconv
+ dim // LDC: if(N) max(1, m)
+ );
+ }
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// Preconditioning
@@ -478,20 +513,37 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func,
// first nbase bands psi* dot notconv bands spsi to prepare lagrange_matrix
// calculate the square matrix for future lagranges
- ModuleBase::gemm_op()('C',
- 'N',
- nbase, // m: row of A,C
- notconv, // n: col of B,C
- dim, // k: col of A, row of B
- this->one, // alpha
- basis, // A
- dim, // LDA: if(N) max(1,m) if(T) max(1,k)
- &spsi[nbase * dim], // B
- dim, // LDB: if(N) max(1,k) if(T) max(1,n)
- this->zero, // belta
- lagrange, // C
- nbase + notconv // LDC: if(N) max(1, m)
- );
+ if (notconv == 1){
+ //Use gemv for vector case to avoid potential bug using gemm call with n=1
+ ModuleBase::gemv_op()('C',
+ dim, // m: row of A
+ nbase, // n: col of A
+ this->one, // alpha
+ basis, // A dim * nbase
+ dim, // LDA: if(N) max(1,m)
+ &spsi[nbase * dim], // X dim
+ 1, // incx
+ this->zero, // beta
+ lagrange, // Y nbase
+ 1
+ );
+ } else
+ {
+ ModuleBase::gemm_op()('C',
+ 'N',
+ nbase, // m: row of A,C
+ notconv, // n: col of B,C
+ dim, // k: col of A, row of B
+ this->one, // alpha
+ basis, // A
+ dim, // LDA: if(N) max(1,m) if(T) max(1,k)
+ &spsi[nbase * dim], // B
+ dim, // LDB: if(N) max(1,k) if(T) max(1,n)
+ this->zero, // belta
+ lagrange, // C
+ nbase + notconv // LDC: if(N) max(1, m)
+ );
+ }
for (int m = 0; m < notconv; m++)
{