Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/experimental/r_KLU_GLU_matrix_values_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,15 @@ int main(int argc, char* argv[])
std::cout << "KLU analysis status: " << status << std::endl;
status = KLU->factorize();
std::cout << "KLU factorization status: " << status << std::endl;
matrix_type* L = KLU->getLFactorCsr();
matrix_type* U = KLU->getUFactorCsr();
matrix_type* L = KLU->getLFactor();
matrix_type* U = KLU->getUFactor();
if (L == nullptr)
{
printf("ERROR");
}
index_type* P = KLU->getPOrdering();
index_type* Q = KLU->getQOrdering();
GLU->setupCsr(A, L, U, P, Q);
GLU->setup(A, L, U, P, Q);
status = GLU->solve(vec_rhs, vec_x);
std::cout << "GLU solve status: " << status << std::endl;
// status = KLU->solve(vec_rhs, vec_x);
Expand Down
12 changes: 6 additions & 6 deletions examples/experimental/r_KLU_cusolverrf_redo_factorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ int main(int argc, char* argv[])
std::cout << "KLU solve status: " << status << std::endl;
if (i == 1)
{
L = (ReSolve::matrix::Csr*) KLU->getLFactorCsr();
U = (ReSolve::matrix::Csr*) KLU->getUFactorCsr();
L = (ReSolve::matrix::Csr*) KLU->getLFactor();
U = (ReSolve::matrix::Csr*) KLU->getUFactor();
if (L == nullptr)
{
std::cout << "ERROR: L factor is null\n";
Expand All @@ -163,7 +163,7 @@ int main(int argc, char* argv[])
}
P = KLU->getPOrdering();
Q = KLU->getQOrdering();
Rf->setupCsr(A, L, U, P, Q);
Rf->setup(A, L, U, P, Q);
status_refactor = Rf->refactorize();
std::cout << "Initial Rf refactorization status: " << status_refactor << std::endl;

Expand Down Expand Up @@ -223,15 +223,15 @@ int main(int argc, char* argv[])
<< std::scientific << std::setprecision(16)
<< res_nrm / b_nrm << "\n";

L = (ReSolve::matrix::Csr*) KLU->getLFactorCsr();
U = (ReSolve::matrix::Csr*) KLU->getUFactorCsr();
L = (ReSolve::matrix::Csr*) KLU->getLFactor();
U = (ReSolve::matrix::Csr*) KLU->getUFactor();

if (L != nullptr && U != nullptr)
{
P = KLU->getPOrdering();
Q = KLU->getQOrdering();

Rf->setupCsr(A, L, U, P, Q);
Rf->setup(A, L, U, P, Q);
status_refactor = Rf->refactorize();
std::cout << "Rf refactorization after KLU redo status: " << status_refactor << std::endl;
}
Expand Down
6 changes: 3 additions & 3 deletions examples/experimental/r_KLU_rf_FGMRES_reuse_factorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,15 @@ int main(int argc, char* argv[])
<< sqrt(vector_handler->dot(vec_r, vec_r, ReSolve::memory::DEVICE)) / norm_b << "\n";
if (i == 1)
{
ReSolve::matrix::Csr* L = (ReSolve::matrix::Csr*) KLU->getLFactorCsr();
ReSolve::matrix::Csr* U = (ReSolve::matrix::Csr*) KLU->getUFactorCsr();
ReSolve::matrix::Csr* L = (ReSolve::matrix::Csr*) KLU->getLFactor();
ReSolve::matrix::Csr* U = (ReSolve::matrix::Csr*) KLU->getUFactor();
if (L == nullptr)
{
std::cout << "ERROR\n";
}
index_type* P = KLU->getPOrdering();
index_type* Q = KLU->getQOrdering();
Rf->setupCsr(A, L, U, P, Q);
Rf->setup(A, L, U, P, Q);
std::cout << "about to set FGMRES" << std::endl;
FGMRES->setRestart(1000);
FGMRES->setMaxit(2000);
Expand Down
6 changes: 3 additions & 3 deletions examples/experimental/r_KLU_rocsolverrf_asym6x6.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ int main()

helper.resetSystem(A, vec_rhs, vec_x);
helper.printShortSummary();
ReSolve::matrix::Csr* L = (ReSolve::matrix::Csr*) KLU.getLFactorCsr();
ReSolve::matrix::Csr* U = (ReSolve::matrix::Csr*) KLU.getUFactorCsr();
ReSolve::matrix::Csr* L = (ReSolve::matrix::Csr*) KLU.getLFactor();
ReSolve::matrix::Csr* U = (ReSolve::matrix::Csr*) KLU.getUFactor();
if (L == nullptr || U == nullptr)
{
std::cout << "Factor extraction from KLU failed!\n";
Expand All @@ -196,7 +196,7 @@ int main()
std::cout << rhs_before[i] << (i < n - 1 ? ", " : "");
}
std::cout << "]" << std::endl;
Rf.setupCsr(A, L, U, P, Q, vec_rhs);
Rf.setup(A, L, U, P, Q, vec_rhs);
std::cout << "RocSolverRf setup completed\n";

// Test refactorization with the same matrix (in practice, you'd change matrix values)
Expand Down
12 changes: 6 additions & 6 deletions examples/experimental/r_KLU_rocsolverrf_redo_factorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,12 @@ int main(int argc, char* argv[])
std::cout << "KLU solve status: " << status << std::endl;
if (i == 1)
{
ReSolve::matrix::Csr* L = (ReSolve::matrix::Csr*) KLU->getLFactorCsr();
ReSolve::matrix::Csr* U = (ReSolve::matrix::Csr*) KLU->getUFactorCsr();
ReSolve::matrix::Csr* L = (ReSolve::matrix::Csr*) KLU->getLFactor();
ReSolve::matrix::Csr* U = (ReSolve::matrix::Csr*) KLU->getUFactor();
index_type* P = KLU->getPOrdering();
index_type* Q = KLU->getQOrdering();
vec_rhs->copyDataFrom(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
Rf->setupCsr(A, L, U, P, Q, vec_rhs);
Rf->setup(A, L, U, P, Q, vec_rhs);
Rf->refactorize();
}
}
Expand Down Expand Up @@ -193,13 +193,13 @@ int main(int argc, char* argv[])
<< std::scientific << std::setprecision(16)
<< res_nrm / b_nrm << "\n";

ReSolve::matrix::Csr* L = (ReSolve::matrix::Csr*) KLU->getLFactorCsr();
ReSolve::matrix::Csr* U = (ReSolve::matrix::Csr*) KLU->getUFactorCsr();
ReSolve::matrix::Csr* L = (ReSolve::matrix::Csr*) KLU->getLFactor();
ReSolve::matrix::Csr* U = (ReSolve::matrix::Csr*) KLU->getUFactor();

index_type* P = KLU->getPOrdering();
index_type* Q = KLU->getQOrdering();

Rf->setupCsr(A, L, U, P, Q, vec_rhs);
Rf->setup(A, L, U, P, Q, vec_rhs);
}
}

Expand Down
6 changes: 3 additions & 3 deletions examples/gluRefactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,15 @@ int gluRefactor(int argc, char* argv[])
std::cout << "KLU solve status: " << status << std::endl;

// Extract factors and configure refactorization solver
matrix::Csr* L = (matrix::Csr*) KLU.getLFactorCsr();
matrix::Csr* U = (matrix::Csr*) KLU.getUFactorCsr();
matrix::Csr* L = (matrix::Csr*) KLU.getLFactor();
matrix::Csr* U = (matrix::Csr*) KLU.getUFactor();
if (L == nullptr || U == nullptr)
{
std::cout << "Factor extraction from KLU failed!\n";
}
index_type* P = KLU.getPOrdering();
index_type* Q = KLU.getQOrdering();
status = Rf.setupCsr(A, L, U, P, Q);
status = Rf.setup(A, L, U, P, Q);
std::cout << "Refactorization setup status: " << status << std::endl;

RESOLVE_RANGE_POP("KLU");
Expand Down
6 changes: 3 additions & 3 deletions examples/gpuRefactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,16 +266,16 @@ int gpuRefactor(int argc, char* argv[])
if (i == 1)
{
// Extract factors and configure refactorization solver
matrix::Csr* L = (matrix::Csr*) KLU.getLFactorCsr();
matrix::Csr* U = (matrix::Csr*) KLU.getUFactorCsr();
matrix::Csr* L = (matrix::Csr*) KLU.getLFactor();
matrix::Csr* U = (matrix::Csr*) KLU.getUFactor();
if (L == nullptr || U == nullptr)
{
std::cout << "Factor extraction from KLU failed!\n";
}
index_type* P = KLU.getPOrdering();
index_type* Q = KLU.getQOrdering();

Rf.setupCsr(A, L, U, P, Q, vec_rhs);
Rf.setup(A, L, U, P, Q, vec_rhs);

// Setup iterative refinement solver
if (is_iterative_refinement)
Expand Down
47 changes: 2 additions & 45 deletions resolve/LinSolverDirect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace ReSolve
}

/**
* @brief Setup function for LinSolverDirect class.
* @brief Setup function for LinSolverDirect class with CSR data
*
* @param[in] A - matrix to be solved
* @param[in] L - optional lower triangular factor
Expand All @@ -61,33 +61,6 @@ namespace ReSolve
return 0;
}

/**
* @brief Setup function for LinSolverDirect class with CSR data
*
* @param[in] A - matrix to be solved
* @param[in] L - optional lower triangular factor
* @param[in] U - optional upper triangular factor
* @param[in] P - optional row permutation vector
* @param[in] Q - optional column permutation vector
* @param[in] rhs - optional right-hand side vector
*
* @return int - error code, 0 if successful
*/
int LinSolverDirect::setupCsr(matrix::Sparse* A,
matrix::Sparse* /* L */,
matrix::Sparse* /* U */,
index_type* /* P */,
index_type* /* Q */,
vector_type* /* rhs */)
{
if (A == nullptr)
{
return 1;
}
A_ = A;
return 0;
}

/**
* @brief Placeholder function for symbolic factorization.
*/
Expand Down Expand Up @@ -115,29 +88,13 @@ namespace ReSolve
/**
* @brief Placeholder function for lower triangular factor in Csr.
*/
matrix::Sparse* LinSolverDirect::getLFactorCsr()
{
return nullptr;
}

/**
* @brief Placeholder function for upper triangular factor in Csr.
*/
matrix::Sparse* LinSolverDirect::getUFactorCsr()
{
return nullptr;
}

/**
* @brief Placeholder function for lower triangular factor.
*/
matrix::Sparse* LinSolverDirect::getLFactor()
{
return nullptr;
}

/**
* @brief Placeholder function for upper triangular factor.
* @brief Placeholder function for upper triangular factor in Csr.
*/
matrix::Sparse* LinSolverDirect::getUFactor()
{
Expand Down
12 changes: 2 additions & 10 deletions resolve/LinSolverDirect.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,20 @@ namespace ReSolve
public:
LinSolverDirect();
virtual ~LinSolverDirect();
virtual int setup(matrix::Sparse* A = nullptr,

virtual int setup(matrix::Sparse* A,
matrix::Sparse* L = nullptr,
matrix::Sparse* U = nullptr,
index_type* P = nullptr,
index_type* Q = nullptr,
vector_type* rhs = nullptr);

virtual int setupCsr(matrix::Sparse* A,
matrix::Sparse* L = nullptr,
matrix::Sparse* U = nullptr,
index_type* P = nullptr,
index_type* Q = nullptr,
vector_type* rhs = nullptr);

virtual int analyze(); // the same as symbolic factorization
virtual int factorize();
virtual int refactorize();
virtual int solve(vector_type* rhs, vector_type* x) = 0;
virtual int solve(vector_type* x) = 0;

virtual matrix::Sparse* getLFactorCsr();
virtual matrix::Sparse* getUFactorCsr();
virtual matrix::Sparse* getLFactor();
virtual matrix::Sparse* getUFactor();
virtual index_type* getPOrdering();
Expand Down
Loading
Loading