From cdfc9726c8d3409b26a30c8a87f4aa209426a3d0 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Schmidt Date: Sat, 28 Aug 2021 13:15:41 -0400 Subject: [PATCH 1/2] SplitMatrix can be created from any MatrixBase --- src/quantcore/matrix/split_matrix.py | 50 ++++++++++++++++++------ src/quantcore/matrix/standardized_mat.py | 3 +- tests/test_matrices.py | 7 ++++ 3 files changed, 47 insertions(+), 13 deletions(-) diff --git a/src/quantcore/matrix/split_matrix.py b/src/quantcore/matrix/split_matrix.py index 22554fa4..1a9e7ef2 100644 --- a/src/quantcore/matrix/split_matrix.py +++ b/src/quantcore/matrix/split_matrix.py @@ -1,5 +1,5 @@ import warnings -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np from scipy import sparse as sps @@ -9,6 +9,7 @@ from .ext.split import is_sorted, split_col_subsets from .matrix_base import MatrixBase from .sparse_matrix import SparseMatrix +from .standardized_mat import StandardizedMatrix from .util import ( check_matvec_out_shape, check_transpose_matvec_out_shape, @@ -16,6 +17,24 @@ ) +def as_mx(a: Any): + """Convert an array to a corresponding MatrixBase type. + + If the input is already a MatrixBase, return untouched. + If the input is sparse, return a SparseMatrix. + If the input is a numpy array, return a DenseMatrix. + Raise an error is input is another type. + """ + if isinstance(a, (MatrixBase, StandardizedMatrix)): + return a + elif sps.issparse(a): + return SparseMatrix(a) + elif isinstance(a, np.ndarray): + return DenseMatrix(a) + else: + raise ValueError(f"Cannot convert type {type(a)} to Matrix.") + + def split_sparse_and_dense_parts( arg1: sps.csc_matrix, threshold: float = 0.1 ) -> Tuple[DenseMatrix, SparseMatrix, np.ndarray, np.ndarray]: @@ -126,24 +145,29 @@ def __init__( matrices: List[Union[DenseMatrix, SparseMatrix, CategoricalMatrix]], indices: Optional[List[np.ndarray]] = None, ): + flatten_matrices = [] # First check that all matrices are valid types - for _, mat in enumerate(matrices): + for mat in matrices: if not isinstance(mat, MatrixBase): raise ValueError( "Expected all elements of matrices to be subclasses of MatrixBase." ) if isinstance(mat, SplitMatrix): - raise ValueError("Elements of matrices cannot be SplitMatrix.") + # Flatten out the SplitMatrix + for imat in mat.matrices: + flatten_matrices.append(imat) + else: + flatten_matrices.append(mat) # Now that we know these are all MatrixBase, we can check consistent # shapes and dtypes. - self.dtype = matrices[0].dtype - n_row = matrices[0].shape[0] - for i, mat in enumerate(matrices): + self.dtype = flatten_matrices[0].dtype + n_row = flatten_matrices[0].shape[0] + for i, mat in enumerate(flatten_matrices): if mat.dtype != self.dtype: warnings.warn( "Matrices do not all have the same dtype. Dtypes are " - f"{[elt.dtype for elt in matrices]}." + f"{[elt.dtype for elt in flatten_matrices]}." ) if not mat.shape[0] == n_row: raise ValueError( @@ -151,13 +175,15 @@ def __init__( f"but the first matrix has first dimension {n_row} and matrix {i} has " f"first dimension {mat.shape[0]}." ) - if len(mat.shape) != 2: - raise ValueError("All matrices should be two dimensional.") + if mat.ndim == 1: + flatten_matrices[i] = mat[:, np.newaxis] + elif mat.ndim > 2: + raise ValueError("All matrices should be at most two dimensional.") if indices is None: indices = [] current_idx = 0 - for mat in matrices: + for mat in flatten_matrices: indices.append( np.arange(current_idx, current_idx + mat.shape[1], dtype=np.int64) ) @@ -183,14 +209,14 @@ def __init__( assert isinstance(indices, list) - for i, (mat, idx) in enumerate(zip(matrices, indices)): + for i, (mat, idx) in enumerate(zip(flatten_matrices, indices)): if not mat.shape[1] == len(idx): raise ValueError( f"Element {i} of indices should should have length {mat.shape[1]}, " f"but it has shape {idx.shape}" ) - filtered_mats, filtered_idxs = _filter_out_empty(matrices, indices) + filtered_mats, filtered_idxs = _filter_out_empty(flatten_matrices, indices) combined_matrices, combined_indices = _combine_matrices( filtered_mats, filtered_idxs ) diff --git a/src/quantcore/matrix/standardized_mat.py b/src/quantcore/matrix/standardized_mat.py index 16e1c75e..fe5d5cde 100644 --- a/src/quantcore/matrix/standardized_mat.py +++ b/src/quantcore/matrix/standardized_mat.py @@ -3,7 +3,8 @@ import numpy as np from scipy import sparse as sps -from . import MatrixBase, SparseMatrix +from .matrix_base import MatrixBase +from .sparse_matrix import SparseMatrix from .util import ( check_transpose_matvec_out_shape, set_up_rows_or_cols, diff --git a/tests/test_matrices.py b/tests/test_matrices.py index cadf0407..412e0e9d 100644 --- a/tests/test_matrices.py +++ b/tests/test_matrices.py @@ -579,3 +579,10 @@ def test_pandas_to_matrix(): # was being changed in place. assert df["cl_obj"].dtype == object assert df["ds"].dtype == np.float64 + + +@pytest.mark.parametrize("mat", get_all_matrix_base_subclass_mats()) +def test_split_matrix_creation(mat): + sm = mx.SplitMatrix(matrices=[mat, mat]) + assert sm.shape[0] == mat.shape[0] + assert sm.shape[1] == 2 * mat.shape[1] From 6414d5530e95a15b015315a7beeaf282d0d70bfc Mon Sep 17 00:00:00 2001 From: Marc-Antoine Schmidt Date: Sat, 28 Aug 2021 13:53:08 -0400 Subject: [PATCH 2/2] modified test for splitmatrix from 1d --- tests/test_split_matrix.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_split_matrix.py b/tests/test_split_matrix.py index dadc75df..f79f7d0d 100644 --- a/tests/test_split_matrix.py +++ b/tests/test_split_matrix.py @@ -5,6 +5,7 @@ import scipy.sparse as sps import quantcore.matrix as mx +from quantcore.matrix.dense_matrix import DenseMatrix from quantcore.matrix.ext.sparse import csr_dense_sandwich from quantcore.matrix.split_matrix import SplitMatrix, split_sparse_and_dense_parts @@ -229,8 +230,9 @@ def check(mat): many_random_tests(check) -def test_oned_dense_mat(): - m1 = mx.CategoricalMatrix(np.random.randint(0, 10, 10)) - m2 = mx.DenseMatrix(np.random.rand(10)) - with pytest.raises(ValueError): - mx.SplitMatrix([m1, m2]) +def test_init_from_1d(): + m1 = DenseMatrix(np.arange(10, dtype=float)) + m2 = DenseMatrix(np.ones(shape=(10, 2), dtype=float)) + + res = SplitMatrix([m1, m2]) + assert res.shape == (10, 3)