diff --git a/src/quantcore/matrix/categorical_matrix.py b/src/quantcore/matrix/categorical_matrix.py index e95f8dd7..1f01c131 100644 --- a/src/quantcore/matrix/categorical_matrix.py +++ b/src/quantcore/matrix/categorical_matrix.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -7,6 +7,7 @@ from .ext.categorical import matvec, sandwich_categorical, transpose_matvec from .ext.split import sandwich_cat_cat, sandwich_cat_dense from .matrix_base import MatrixBase +from .sparse_matrix import SparseMatrix from .util import ( check_matvec_out_shape, check_transpose_matvec_out_shape, @@ -15,6 +16,23 @@ ) +def _is_indexer_full_length(full_length: int, indexer: Any): + if isinstance(indexer, int): + return full_length == 1 + elif isinstance(indexer, list): + if (np.asarray(indexer) > full_length - 1).any(): + raise IndexError("Index out-of-range.") + return len(set(indexer)) == full_length + elif isinstance(indexer, np.ndarray): + if (indexer > full_length - 1).any(): + raise IndexError("Index out-of-range.") + return len(np.unique(indexer)) == full_length + elif isinstance(indexer, slice): + return len(range(*indexer.indices(full_length))) == full_length + else: + raise ValueError(f"Indexing with {type(indexer)} is not allowed.") + + def _none_to_slice(arr: Optional[np.ndarray], n: int) -> Union[slice, np.ndarray]: if arr is None or len(arr) == n: return slice(None, None, None) @@ -262,8 +280,14 @@ def get_col_stds(self, weights: np.ndarray, col_means: np.ndarray) -> np.ndarray def __getitem__(self, item): if isinstance(item, tuple): row, col = item - if not (isinstance(col, slice) and col == slice(None, None, None)): - raise IndexError("Only column indexing is supported.") + if _is_indexer_full_length(self.shape[1], col): + if isinstance(row, int): + row = [row] + return CategoricalMatrix(self.cat[row]) + else: + # return a SparseMatrix if we subset columns + # TODO: this is inefficient. See issue #101. + return SparseMatrix(self.tocsr()[row, col], dtype=self.dtype) else: row = item if isinstance(row, int): diff --git a/tests/test_categorical_matrix.py b/tests/test_categorical_matrix.py index 876715fe..ec7e9618 100644 --- a/tests/test_categorical_matrix.py +++ b/tests/test_categorical_matrix.py @@ -58,3 +58,9 @@ def test_nulls(mi_element): vec = [0, mi_element, 1] with pytest.raises(ValueError, match="Categorical data can't have missing values"): CategoricalMatrix(vec) + + +def test_categorical_indexing(): + catvec = [0, 1, 2, 0, 1, 2] + mat = CategoricalMatrix(catvec) + mat[:, [0, 1]]