diff --git a/source/source_base/kernels/cuda/math_kernel_op.cu b/source/source_base/kernels/cuda/math_kernel_op.cu index 000b98b8f5..c5b0648c49 100644 --- a/source/source_base/kernels/cuda/math_kernel_op.cu +++ b/source/source_base/kernels/cuda/math_kernel_op.cu @@ -2,6 +2,7 @@ #include "source_base/kernels/math_kernel_op.h" #include "source_psi/psi.h" #include "source_base/tool_quit.h" +#include "source_base/module_container/base/third_party/cublas.h" #include #include @@ -175,9 +176,28 @@ void gemv_op::operator()(const char& trans, const int& incy) { cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op"); - CHECK_CUBLAS(cublasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx)); + CHECK_CUBLAS(cublasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy)); } +template <> +void gemv_op::operator()(const char& trans, + const int& m, + const int& n, + const float* alpha, + const float* A, + const int& lda, + const float* X, + const int& incx, + const float* beta, + float* Y, + const int& incy) +{ + cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op"); + CHECK_CUBLAS(cublasSgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy)); +} + + + template <> void gemv_op, base_device::DEVICE_GPU>::operator()(const char& trans, const int& m, @@ -194,7 +214,7 @@ void gemv_op, base_device::DEVICE_GPU>::operator()(const cha cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op"); cuFloatComplex alpha = make_cuFloatComplex(alpha_in->real(), alpha_in->imag()); cuFloatComplex beta = make_cuFloatComplex(beta_in->real(), beta_in->imag()); - CHECK_CUBLAS(cublasCgemv(cublas_handle, cutrans, m, n, &alpha, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta, (cuFloatComplex*)Y, incx)); + CHECK_CUBLAS(cublasCgemv(cublas_handle, cutrans, m, n, &alpha, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta, (cuFloatComplex*)Y, incy)); } template <> @@ -215,7 +235,7 @@ void gemv_op, base_device::DEVICE_GPU>::operator()(const ch cuDoubleComplex beta = make_cuDoubleComplex(beta_in->real(), beta_in->imag()); // icpc and nvcc have some compatible problems // We must use cuDoubleComplex instead of converting std::complex* to cuDoubleComplex* - CHECK_CUBLAS(cublasZgemv(cublas_handle, cutrans, m, n, &alpha, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta, (cuDoubleComplex*)Y, incx)); + CHECK_CUBLAS(cublasZgemv(cublas_handle, cutrans, m, n, &alpha, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta, (cuDoubleComplex*)Y, incy)); } template <> diff --git a/source/source_base/kernels/rocm/math_kernel_op.hip.cu b/source/source_base/kernels/rocm/math_kernel_op.hip.cu index f60cdfc3b1..1b4f30d6b2 100644 --- a/source/source_base/kernels/rocm/math_kernel_op.hip.cu +++ b/source/source_base/kernels/rocm/math_kernel_op.hip.cu @@ -188,7 +188,7 @@ void gemv_op::operator()(const char& trans, const int& incy) { hipblasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op"); - hipblasErrcheck(hipblasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx)); + hipblasErrcheck(hipblasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy)); } template <> @@ -205,7 +205,7 @@ void gemv_op, base_device::DEVICE_GPU>::operator()(const cha const int& incy) { hipblasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op"); - hipblasErrcheck(hipblasCgemv(cublas_handle, cutrans, m, n, (hipblasComplex*)alpha, (hipblasComplex*)A, lda, (hipblasComplex*)X, incx, (hipblasComplex*)beta, (hipblasComplex*)Y, incx)); + hipblasErrcheck(hipblasCgemv(cublas_handle, cutrans, m, n, (hipblasComplex*)alpha, (hipblasComplex*)A, lda, (hipblasComplex*)X, incx, (hipblasComplex*)beta, (hipblasComplex*)Y, incy)); } template <> @@ -222,7 +222,7 @@ void gemv_op, base_device::DEVICE_GPU>::operator()(const ch const int& incy) { hipblasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op"); - hipblasErrcheck(hipblasZgemv(cublas_handle, cutrans, m, n, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)A, lda, (hipblasDoubleComplex*)X, incx, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)Y, incx)); + hipblasErrcheck(hipblasZgemv(cublas_handle, cutrans, m, n, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)A, lda, (hipblasDoubleComplex*)X, incx, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)Y, incy)); } template <> diff --git a/source/source_base/module_container/base/third_party/cublas.h b/source/source_base/module_container/base/third_party/cublas.h index f6653011a4..fabc32e983 100644 --- a/source/source_base/module_container/base/third_party/cublas.h +++ b/source/source_base/module_container/base/third_party/cublas.h @@ -152,7 +152,7 @@ void gemv_batched(cublasHandle_t& handle, const char& trans, const int& m, const { for (int ii = 0; ii < batch_size; ++ii) { // Call the single GEMV for each pair of matrix A[ii] and vector x[ii] - cuBlasConnector::gemv(handle, trans, m, n, alpha, A[ii], lda, x[ii], incy, beta, y[ii], incy); + cuBlasConnector::gemv(handle, trans, m, n, alpha, A[ii], lda, x[ii], incx, beta, y[ii], incy); } } diff --git a/source/source_hsolver/diago_dav_subspace.cpp b/source/source_hsolver/diago_dav_subspace.cpp index 8fa997774b..27c6a5b348 100644 --- a/source/source_hsolver/diago_dav_subspace.cpp +++ b/source/source_hsolver/diago_dav_subspace.cpp @@ -77,7 +77,7 @@ Diago_DavSubspace::Diago_DavSubspace(const std::vector& precond if (this->device == base_device::GpuDevice) { resmem_real_op()(this->d_precondition, nbasis_in); - // syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition.data(), nbasis_in); + syncmem_var_h2d_op()(this->d_precondition, this->precondition.data(), nbasis_in); resmem_complex_op()(this->d_scc, this->nbase_x * this->nbase_x); resmem_real_op()(this->d_eigenvalue, this->nbase_x); } @@ -295,6 +295,8 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, } } + if (notconv > 1){ + #ifdef __DSP ModuleBase::gemm_op_mt() #else @@ -313,6 +315,28 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, this->zero, psi_iter + (nbase) * this->dim, this->dim); + } else + { + +#ifdef __DSP + ModuleBase::gemv_op_mt() +#else + ModuleBase::gemv_op() +#endif + ('N', + this->dim, // m: row of A + nbase, // n: col of A + this->one, // alpha + hpsi, // A dim * nbase + this->dim, // LDA: if(N) max(1,m) + vcc, // X nbase + 1, // incx + this->zero, // beta + psi_iter + (nbase) * this->dim, // Y dim + 1 // incy + ); + } + // Eigenvalues operation section Real* e_temp_hd = eigenvalue_iter->data(); @@ -325,6 +349,8 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, // vcc = - vcc * eigenvalue ModuleBase::matrix_mul_vector_op()(nbase, notconv, vcc, this->nbase_x, e_temp_hd, -1.0, vcc, this->nbase_x); + if (notconv > 1){ + #ifdef __DSP ModuleBase::gemm_op_mt() #else @@ -343,6 +369,26 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, this->one, psi_iter + nbase * this->dim, this->dim); + } else + { +#ifdef __DSP + ModuleBase::gemv_op_mt() +#else + ModuleBase::gemv_op() +#endif + ('N', + this->dim, // m: row of A + nbase, // n: col of A + this->one, // alpha + spsi, // A dim * nbase + this->dim, // LDA: if(N) max(1,m) + vcc, // X nbase + 1, // incx + this->one, // beta + psi_iter + nbase * this->dim, // Y dim + 1 // incy + ); + } // Precondition section #if defined(__CUDA) || defined(__ROCM) @@ -413,6 +459,7 @@ void Diago_DavSubspace::cal_elem(const int& dim, { ModuleBase::timer::tick("Diago_DavSubspace", "cal_elem"); + if (notconv > 1){ #ifdef __DSP ModuleBase::gemm_op_mt() #else @@ -451,6 +498,46 @@ void Diago_DavSubspace::cal_elem(const int& dim, &scc[nbase * this->nbase_x], this->nbase_x); + } else { + +#ifdef __DSP + ModuleBase::gemv_op_mt() +#else + ModuleBase::gemv_op() +#endif + ('C', + this->dim, // m: row of A + nbase + notconv, // n: col of A + this->one, // alpha + psi_iter, // A dim * nbase + this->dim, // LDA: if(N) max(1,m) + &hpsi[nbase * this->dim], // X nbase + 1, // incx + this->zero, // beta + &hcc[nbase * this->nbase_x], // Y dim + 1 // incy + ); +#ifdef __DSP + ModuleBase::gemv_op_mt() +#else + ModuleBase::gemv_op() +#endif + ('C', + this->dim, // m: row of A + nbase + notconv, // n: col of A + this->one, // alpha + psi_iter, // A dim * nbase + this->dim, // LDA: if(N) max(1,m) + spsi + nbase * this->dim, // X nbase + 1, // incx + this->zero, // beta + &scc[nbase * this->nbase_x], // Y dim + 1 // incy + ); + + } + + #ifdef __MPI if (this->diag_comm.nproc > 1) { diff --git a/source/source_hsolver/diago_david.cpp b/source/source_hsolver/diago_david.cpp index bf390104af..ef4ba67cf3 100644 --- a/source/source_hsolver/diago_david.cpp +++ b/source/source_hsolver/diago_david.cpp @@ -351,7 +351,24 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, // basis[nbase] = hpsi * vc_ev_vector = hpsi*vcc // basis' = vc_ev_vector' * hpsi' // (dim, notconv) (dim, nbase) (nbase, notconv) - ModuleBase::gemm_op()('N', + if (notconv == 1){ + //Reuse gemv for vector case to avoid potential bug using gemm call with n=1 + ModuleBase::gemv_op()('N', + dim, // m: row of A + nbase, // n: col of A + this->one, // alpha + hpsi, // A dim * nbase + dim, // LDA: if(N) max(1,m) + vc_ev_vector, // X nbase + 1, // incx + this->zero, // beta + basis + dim * nbase, // Y dim + 1 // incy + ); + + }else + { + ModuleBase::gemm_op()('N', 'N', dim, // m: row of A,C notconv, // n: col of B,C @@ -364,7 +381,8 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, this->zero, // belta basis + dim * nbase, // C dim * notconv dim // LDC: if(N) max(1, m) - ); + ); + } //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // for (int m = 0; m < notconv; m++) @@ -411,20 +429,37 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, // = (H - lambda * S) * psi * vcc // = (H - lambda * S) * psi_new //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< - ModuleBase::gemm_op()('N', - 'N', - dim, // m: row of A,C - notconv, // n: col of B,C - nbase, // k: col of A, row of B - this->one, // alpha - spsi, // A - dim, // LDA: if(N) max(1,m) if(T) max(1,k) - vc_ev_vector, // B - nbase, // LDB: if(N) max(1,k) if(T) max(1,n) - this->one, // belta - basis + dim * nbase, // C dim * notconv - dim // LDC: if(N) max(1, m) - ); + if (notconv == 1){ + //Use gemv for vector case to avoid potential bug using gemm call with n=1 + ModuleBase::gemv_op()('N', + dim, // m: row of A + nbase, // n: col of A + this->one, // alpha + spsi, // A dim * nbase + dim, // LDA: if(N) max(1,m) + vc_ev_vector, // X nbase + 1, // incx + this->one, // beta + basis + dim * nbase, // Y dim + 1 //incy + ); + } else + { + ModuleBase::gemm_op()('N', + 'N', + dim, // m: row of A,C + notconv, // n: col of B,C + nbase, // k: col of A, row of B + this->one, // alpha + spsi, // A + dim, // LDA: if(N) max(1,m) if(T) max(1,k) + vc_ev_vector, // B + nbase, // LDB: if(N) max(1,k) if(T) max(1,n) + this->one, // beta + basis + dim * nbase, // C dim * notconv + dim // LDC: if(N) max(1, m) + ); + } //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // Preconditioning @@ -478,20 +513,37 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, // first nbase bands psi* dot notconv bands spsi to prepare lagrange_matrix // calculate the square matrix for future lagranges - ModuleBase::gemm_op()('C', - 'N', - nbase, // m: row of A,C - notconv, // n: col of B,C - dim, // k: col of A, row of B - this->one, // alpha - basis, // A - dim, // LDA: if(N) max(1,m) if(T) max(1,k) - &spsi[nbase * dim], // B - dim, // LDB: if(N) max(1,k) if(T) max(1,n) - this->zero, // belta - lagrange, // C - nbase + notconv // LDC: if(N) max(1, m) - ); + if (notconv == 1){ + //Use gemv for vector case to avoid potential bug using gemm call with n=1 + ModuleBase::gemv_op()('C', + dim, // m: row of A + nbase, // n: col of A + this->one, // alpha + basis, // A dim * nbase + dim, // LDA: if(N) max(1,m) + &spsi[nbase * dim], // X dim + 1, // incx + this->zero, // beta + lagrange, // Y nbase + 1 + ); + } else + { + ModuleBase::gemm_op()('C', + 'N', + nbase, // m: row of A,C + notconv, // n: col of B,C + dim, // k: col of A, row of B + this->one, // alpha + basis, // A + dim, // LDA: if(N) max(1,m) if(T) max(1,k) + &spsi[nbase * dim], // B + dim, // LDB: if(N) max(1,k) if(T) max(1,n) + this->zero, // belta + lagrange, // C + nbase + notconv // LDC: if(N) max(1, m) + ); + } for (int m = 0; m < notconv; m++) {