diff --git a/src/tabmat/ext/dense.pyx b/src/tabmat/ext/dense.pyx index ef6dcb54..84dc27fb 100644 --- a/src/tabmat/ext/dense.pyx +++ b/src/tabmat/ext/dense.pyx @@ -38,7 +38,7 @@ def dense_sandwich(np.ndarray X, floating[:] d, int[:] rows, int[:] cols, int th elif X.flags["F_CONTIGUOUS"]: _denseF_sandwich(rowsp, colsp, Xp, dp, outp, in_n, out_m, m, n, thresh1d, kratio, innerblock) else: - raise Exception() + raise Exception("The matrix X is not contiguous.") return out diff --git a/tests/test_fast_sandwich.py b/tests/test_fast_sandwich.py index 77b53d81..051e5569 100644 --- a/tests/test_fast_sandwich.py +++ b/tests/test_fast_sandwich.py @@ -2,10 +2,13 @@ import pytest import scipy as sp import scipy.sparse +from scipy.sparse import csc_matrix from tabmat.ext.dense import dense_sandwich from tabmat.ext.sparse import sparse_sandwich +from tabmat import DenseMatrix, SparseMatrix, SplitMatrix + @pytest.mark.parametrize("dtype", [np.float64, np.float32]) def test_fast_sandwich_sparse(dtype): @@ -62,6 +65,28 @@ def test_fast_sandwich_dense(): check(A, d, cols) +def test_dense_sandwich_on_non_contiguous(): + """Non-regression test for #208""" + rng = np.random.default_rng(seed=123) + X = rng.standard_normal(size=(100, 20)) + + # Xd wraps a not-contiguous array. + Xd = DenseMatrix(X[:, :10]) + Xs = SparseMatrix(csc_matrix(X[:, 10:])) + Xm = SplitMatrix([Xd, Xs]) + + # Making the sandwich product fail. + with pytest.raises(Exception, match="The matrix X is not contiguous"): + Xm.sandwich(np.ones(X.shape[0])) + + # Xd wraps a copy, which makes the data contiguous. + Xd = DenseMatrix(X[:, :10].copy()) + Xm = SplitMatrix([Xd, Xs]) + + # The sandwich product works without problem here. + Xm.sandwich(np.ones(X.shape[0])) + + def check(A, d, cols): Asub = A[:, cols] true = (Asub.T.multiply(d)).dot(Asub).toarray()