diff --git a/CHANGELOG.md b/CHANGELOG.md index e3d9cc3e9..501d26280 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Unreleased ### Added +- Added `getBase()` and `setBase()` methods to `LP` class for getting/setting basis status - Added `getMemUsed()`, `getMemTotal()`, and `getMemExternEstim()` methods ### Fixed - Removed `Py_INCREF`/`Py_DECREF` on `Model` in `catchEvent`/`dropEvent` that caused memory leak for imbalanced usage diff --git a/src/pyscipopt/__init__.py b/src/pyscipopt/__init__.py index fafa440ad..170d86e42 100644 --- a/src/pyscipopt/__init__.py +++ b/src/pyscipopt/__init__.py @@ -28,6 +28,7 @@ from pyscipopt.scip import LP as LP from pyscipopt.scip import IISfinder as IISfinder from pyscipopt.scip import PY_SCIP_LPPARAM as SCIP_LPPARAM +from pyscipopt.scip import PY_SCIP_BASESTAT as SCIP_BASESTAT from pyscipopt.scip import readStatistics as readStatistics from pyscipopt.scip import Expr as Expr from pyscipopt.scip import MatrixExpr as MatrixExpr diff --git a/src/pyscipopt/lp.pxi b/src/pyscipopt/lp.pxi index 59e5b2115..a33961190 100644 --- a/src/pyscipopt/lp.pxi +++ b/src/pyscipopt/lp.pxi @@ -531,6 +531,66 @@ cdef class LP: return binds + def getBase(self): + """Returns the basis status of columns and rows. + + Status values are defined in SCIP_BASESTAT: LOWER, BASIC, UPPER, ZERO. + + Returns + ------- + tuple of (list of int, list of int) + Column basis statuses and row basis statuses. + + """ + cdef int ncols = self.ncols() + cdef int nrows = self.nrows() + cdef int* c_cstat = malloc(ncols * sizeof(int)) + cdef int* c_rstat = malloc(nrows * sizeof(int)) + cdef int i + + PY_SCIP_CALL(SCIPlpiGetBase(self.lpi, c_cstat, c_rstat)) + + cstat = [c_cstat[i] for i in range(ncols)] + rstat = [c_rstat[i] for i in range(nrows)] + + free(c_rstat) + free(c_cstat) + + return cstat, rstat + + def setBase(self, cstat, rstat): + """Sets the basis status of columns and rows. + + Status values are defined in SCIP_BASESTAT: LOWER, BASIC, UPPER, ZERO. + + Parameters + ---------- + cstat : list of int + Column basis statuses (length must equal ncols). + rstat : list of int + Row basis statuses (length must equal nrows). + + """ + cdef int ncols = self.ncols() + cdef int nrows = self.nrows() + if len(cstat) != ncols: + raise ValueError(f"cstat has length {len(cstat)}, expected {ncols}") + if len(rstat) != nrows: + raise ValueError(f"rstat has length {len(rstat)}, expected {nrows}") + cdef int* c_cstat = malloc(ncols * sizeof(int)) + cdef int* c_rstat = malloc(nrows * sizeof(int)) + cdef int i + + for i in range(ncols): + c_cstat[i] = cstat[i] + for i in range(nrows): + c_rstat[i] = rstat[i] + + PY_SCIP_CALL(SCIPlpiSetBase(self.lpi, c_cstat, c_rstat)) + + free(c_rstat) + free(c_cstat) + # Parameter Methods def setIntParam(self, param, value): diff --git a/src/pyscipopt/scip.pxd b/src/pyscipopt/scip.pxd index cb756dc1f..8ef69adfe 100644 --- a/src/pyscipopt/scip.pxd +++ b/src/pyscipopt/scip.pxd @@ -1532,6 +1532,8 @@ cdef extern from "scip/scip.h": SCIP_RETCODE SCIPlpiGetPrimalRay(SCIP_LPI* lpi, SCIP_Real* ray) SCIP_RETCODE SCIPlpiGetDualfarkas(SCIP_LPI* lpi, SCIP_Real* dualfarkas) SCIP_RETCODE SCIPlpiGetBasisInd(SCIP_LPI* lpi, int* bind) + SCIP_RETCODE SCIPlpiGetBase(SCIP_LPI* lpi, int* cstat, int* rstat) + SCIP_RETCODE SCIPlpiSetBase(SCIP_LPI* lpi, const int* cstat, const int* rstat) SCIP_RETCODE SCIPlpiGetRealSolQuality(SCIP_LPI* lpi, SCIP_LPSOLQUALITY qualityindicator, SCIP_Real* quality) SCIP_RETCODE SCIPlpiGetIntpar(SCIP_LPI* lpi, SCIP_LPPARAM type, int* ival) SCIP_RETCODE SCIPlpiGetRealpar(SCIP_LPI* lpi, SCIP_LPPARAM type, SCIP_Real* dval) diff --git a/src/pyscipopt/scip.pxi b/src/pyscipopt/scip.pxi index bdea41bd5..871638ebf 100644 --- a/src/pyscipopt/scip.pxi +++ b/src/pyscipopt/scip.pxi @@ -118,6 +118,12 @@ cdef class PY_SCIP_LPPARAM: POLISHING = SCIP_LPPAR_POLISHING REFACTOR = SCIP_LPPAR_REFACTOR +cdef class PY_SCIP_BASESTAT: + LOWER = SCIP_BASESTAT_LOWER + BASIC = SCIP_BASESTAT_BASIC + UPPER = SCIP_BASESTAT_UPPER + ZERO = SCIP_BASESTAT_ZERO + cdef class PY_SCIP_PARAMEMPHASIS: DEFAULT = SCIP_PARAMEMPHASIS_DEFAULT CPSOLVER = SCIP_PARAMEMPHASIS_CPSOLVER diff --git a/src/pyscipopt/scip.pyi b/src/pyscipopt/scip.pyi index 2728cf513..92cc58540 100644 --- a/src/pyscipopt/scip.pyi +++ b/src/pyscipopt/scip.pyi @@ -467,6 +467,7 @@ class LP: def delCols(self, firstcol: Incomplete, lastcol: Incomplete) -> Incomplete: ... def delRows(self, firstrow: Incomplete, lastrow: Incomplete) -> Incomplete: ... def getActivity(self) -> Incomplete: ... + def getBase(self) -> Incomplete: ... def getBasisInds(self) -> Incomplete: ... def getBounds( self, firstcol: Incomplete = ..., lastcol: Incomplete = ... @@ -491,6 +492,7 @@ class LP: def ncols(self) -> Incomplete: ... def nrows(self) -> Incomplete: ... def readLP(self, filename: Incomplete) -> Incomplete: ... + def setBase(self, cstat: Incomplete, rstat: Incomplete) -> Incomplete: ... def setIntParam(self, param: Incomplete, value: Incomplete) -> Incomplete: ... def setRealParam(self, param: Incomplete, value: Incomplete) -> Incomplete: ... def solve(self, dual: Incomplete = ...) -> Incomplete: ... @@ -1801,6 +1803,13 @@ class PY_SCIP_LOCKTYPE: MODEL: ClassVar[int] = ... def __init__(self) -> None: ... +class PY_SCIP_BASESTAT: + LOWER: ClassVar[int] = ... + BASIC: ClassVar[int] = ... + UPPER: ClassVar[int] = ... + ZERO: ClassVar[int] = ... + def __init__(self) -> None: ... + class PY_SCIP_LPPARAM: BARRIERCONVTOL: ClassVar[int] = ... CONDITIONLIMIT: ClassVar[int] = ... diff --git a/tests/test_lp.py b/tests/test_lp.py index 7cc90585d..deb416d58 100644 --- a/tests/test_lp.py +++ b/tests/test_lp.py @@ -1,5 +1,6 @@ from pyscipopt import LP from pyscipopt import SCIP_LPPARAM +from pyscipopt import SCIP_BASESTAT def test_lp(): # create LP instance, minimizing by default @@ -89,3 +90,25 @@ def test_lp(): assert round(myLP.getObjVal() == solval) assert round(5.0 == solval) + + # test basis get/set + binds = myLP.getBasisInds() + assert len(binds) == myLP.nrows() + + cstat, rstat = myLP.getBase() + assert len(cstat) == myLP.ncols() + assert len(rstat) == myLP.nrows() + assert all(s in (SCIP_BASESTAT.LOWER, SCIP_BASESTAT.BASIC, + SCIP_BASESTAT.UPPER, SCIP_BASESTAT.ZERO) for s in cstat) + assert all(s in (SCIP_BASESTAT.LOWER, SCIP_BASESTAT.BASIC, + SCIP_BASESTAT.UPPER) for s in rstat) + + # set the same basis back and re-solve + myLP.setBase(cstat, rstat) + solval2 = myLP.solve() + assert round(solval2, 10) == round(solval, 10) + + # verify basis is preserved after set + cstat2, rstat2 = myLP.getBase() + assert cstat2 == cstat + assert rstat2 == rstat