Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions src/quantcore/matrix/split_matrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

import numpy as np
from scipy import sparse as sps
Expand All @@ -9,13 +9,32 @@
from .ext.split import is_sorted, split_col_subsets
from .matrix_base import MatrixBase
from .sparse_matrix import SparseMatrix
from .standardized_mat import StandardizedMatrix
from .util import (
check_matvec_out_shape,
check_transpose_matvec_out_shape,
set_up_rows_or_cols,
)


def as_mx(a: Any):
    """Coerce ``a`` into one of this package's matrix types.

    - A ``MatrixBase`` or ``StandardizedMatrix`` instance is returned as is.
    - A scipy sparse input is wrapped in a ``SparseMatrix``.
    - A ``numpy.ndarray`` is wrapped in a ``DenseMatrix``.

    Raises
    ------
    ValueError
        If ``a`` is none of the types listed above.
    """
    # Already a matrix type handled by this package: hand it back unchanged.
    if isinstance(a, (MatrixBase, StandardizedMatrix)):
        return a

    # The sparse check comes before the ndarray check, as in the original order.
    if sps.issparse(a):
        return SparseMatrix(a)

    if isinstance(a, np.ndarray):
        return DenseMatrix(a)

    raise ValueError(f"Cannot convert type {type(a)} to Matrix.")


def split_sparse_and_dense_parts(
arg1: sps.csc_matrix, threshold: float = 0.1
) -> Tuple[DenseMatrix, SparseMatrix, np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -126,38 +145,45 @@ def __init__(
matrices: List[Union[DenseMatrix, SparseMatrix, CategoricalMatrix]],
indices: Optional[List[np.ndarray]] = None,
):
flatten_matrices = []
# First check that all matrices are valid types
for _, mat in enumerate(matrices):
for mat in matrices:
if not isinstance(mat, MatrixBase):
raise ValueError(
"Expected all elements of matrices to be subclasses of MatrixBase."
)
if isinstance(mat, SplitMatrix):
raise ValueError("Elements of matrices cannot be SplitMatrix.")
# Flatten out the SplitMatrix
for imat in mat.matrices:
flatten_matrices.append(imat)
else:
flatten_matrices.append(mat)

# Now that we know these are all MatrixBase, we can check consistent
# shapes and dtypes.
self.dtype = matrices[0].dtype
n_row = matrices[0].shape[0]
for i, mat in enumerate(matrices):
self.dtype = flatten_matrices[0].dtype
n_row = flatten_matrices[0].shape[0]
for i, mat in enumerate(flatten_matrices):
if mat.dtype != self.dtype:
warnings.warn(
"Matrices do not all have the same dtype. Dtypes are "
f"{[elt.dtype for elt in matrices]}."
f"{[elt.dtype for elt in flatten_matrices]}."
)
if not mat.shape[0] == n_row:
raise ValueError(
"All matrices should have the same first dimension, "
f"but the first matrix has first dimension {n_row} and matrix {i} has "
f"first dimension {mat.shape[0]}."
)
if len(mat.shape) != 2:
raise ValueError("All matrices should be two dimensional.")
if mat.ndim == 1:
flatten_matrices[i] = mat[:, np.newaxis]
elif mat.ndim > 2:
raise ValueError("All matrices should be at most two dimensional.")

if indices is None:
indices = []
current_idx = 0
for mat in matrices:
for mat in flatten_matrices:
indices.append(
np.arange(current_idx, current_idx + mat.shape[1], dtype=np.int64)
)
Expand All @@ -183,14 +209,14 @@ def __init__(

assert isinstance(indices, list)

for i, (mat, idx) in enumerate(zip(matrices, indices)):
for i, (mat, idx) in enumerate(zip(flatten_matrices, indices)):
if not mat.shape[1] == len(idx):
raise ValueError(
f"Element {i} of indices should should have length {mat.shape[1]}, "
f"but it has shape {idx.shape}"
)

filtered_mats, filtered_idxs = _filter_out_empty(matrices, indices)
filtered_mats, filtered_idxs = _filter_out_empty(flatten_matrices, indices)
combined_matrices, combined_indices = _combine_matrices(
filtered_mats, filtered_idxs
)
Expand Down
3 changes: 2 additions & 1 deletion src/quantcore/matrix/standardized_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
from scipy import sparse as sps

from . import MatrixBase, SparseMatrix
from .matrix_base import MatrixBase
from .sparse_matrix import SparseMatrix
from .util import (
check_transpose_matvec_out_shape,
set_up_rows_or_cols,
Expand Down
7 changes: 7 additions & 0 deletions tests/test_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,3 +579,10 @@ def test_pandas_to_matrix():
# was being changed in place.
assert df["cl_obj"].dtype == object
assert df["ds"].dtype == np.float64


@pytest.mark.parametrize("mat", get_all_matrix_base_subclass_mats())
def test_split_matrix_creation(mat):
    """Stacking a matrix with itself keeps the row count and doubles the columns."""
    # Build the SplitMatrix first, then compare its shape against the input's.
    stacked = mx.SplitMatrix(matrices=[mat, mat])
    n_rows = mat.shape[0]
    n_cols = mat.shape[1]
    assert stacked.shape[0] == n_rows
    assert stacked.shape[1] == 2 * n_cols
12 changes: 7 additions & 5 deletions tests/test_split_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import scipy.sparse as sps

import quantcore.matrix as mx
from quantcore.matrix.dense_matrix import DenseMatrix
from quantcore.matrix.ext.sparse import csr_dense_sandwich
from quantcore.matrix.split_matrix import SplitMatrix, split_sparse_and_dense_parts

Expand Down Expand Up @@ -229,8 +230,9 @@ def check(mat):
many_random_tests(check)


def test_oned_dense_mat():
    """A SplitMatrix built from a categorical and a 1-d dense matrix raises ValueError."""
    # Keep the random draws in the same order as before (categorical first).
    categorical = mx.CategoricalMatrix(np.random.randint(0, 10, 10))
    one_dim_dense = mx.DenseMatrix(np.random.rand(10))
    parts = [categorical, one_dim_dense]
    # Only the SplitMatrix construction itself is expected to raise.
    with pytest.raises(ValueError):
        mx.SplitMatrix(parts)
def test_init_from_1d():
    """A 1-d DenseMatrix and a (10, 2) DenseMatrix combine into a (10, 3) SplitMatrix."""
    one_dim = DenseMatrix(np.arange(10, dtype=float))
    two_dim = DenseMatrix(np.ones(shape=(10, 2), dtype=float))

    combined = SplitMatrix([one_dim, two_dim])

    # One column from the 1-d input plus two from the 2-d input.
    expected_shape = (10, 3)
    assert combined.shape == expected_shape