diff --git a/CHANGELOG.md b/CHANGELOG.md index 2555e455d..d4556c30e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Added - Added possibility of having variables in exponent. - Added basic type stubs to help with IDE autocompletion and type checking. +- MatrixVariable comparisons (<=, >=, ==) now support numpy's broadcast feature. ### Fixed - Implemented all binary operations between MatrixExpr and GenExpr - Fixed the type of @ matrix operation result from MatrixVariable to MatrixExpr. diff --git a/src/pyscipopt/matrix.pxi b/src/pyscipopt/matrix.pxi index 4c9329f55..8353ed767 100644 --- a/src/pyscipopt/matrix.pxi +++ b/src/pyscipopt/matrix.pxi @@ -28,15 +28,14 @@ def _matrixexpr_richcmp(self, other, op): else: raise NotImplementedError("Can only support constraints with '<=', '>=', or '=='.") - res = np.empty(self.shape, dtype=object) if _is_number(other) or isinstance(other, Expr): + res = np.empty(self.shape, dtype=object) res.flat = [_richcmp(i, other, op) for i in self.flat] elif isinstance(other, np.ndarray): - if self.shape != other.shape: - raise ValueError("Shapes do not match for comparison.") - - res.flat = [_richcmp(i, j, op) for i, j in zip(self.flat, other.flat)] + out = np.broadcast(self, other) + res = np.empty(out.shape, dtype=object) + res.flat = [_richcmp(i, j, op) for i, j in out] else: raise TypeError(f"Unsupported type {type(other)}") diff --git a/tests/test_matrix_variable.py b/tests/test_matrix_variable.py index a251515cc..27f549000 100644 --- a/tests/test_matrix_variable.py +++ b/tests/test_matrix_variable.py @@ -1,10 +1,4 @@ import operator -import pdb -import pprint -import pytest -from pyscipopt import Model, Variable, log, exp, cos, sin, sqrt -from pyscipopt import Expr, MatrixExpr, MatrixVariable, MatrixExprCons, MatrixConstraint, ExprCons -from pyscipopt.scip import GenExpr from time import time import numpy as np @@ -22,10 +16,10 @@ cos, exp, log, - quicksum, sin, sqrt, ) +from pyscipopt.scip import GenExpr def test_catching_errors(): @@ -525,3 +519,16 @@ def test_matrix_matmul_return_type(): y = m.addMatrixVar((2, 3)) z = m.addMatrixVar((3, 4)) assert type(y @ z) is MatrixExpr + + +def test_broadcast(): + # test #1065 + m = Model() + x = m.addMatrixVar((2, 3), ub=10) + + m.addMatrixCons(x == np.zeros((2, 1))) + + m.setObjective(x.sum(), "maximize") + m.optimize() + + assert (m.getVal(x) == np.zeros((2, 3))).all()