From c723727107e19016e1bfff19c8193062e36bca55 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Schmidt Date: Sat, 28 Aug 2021 13:37:37 -0400 Subject: [PATCH 1/2] categorical indexing --- src/quantcore/matrix/categorical_matrix.py | 30 +++++++++++++++++++--- tests/test_matrices.py | 6 +++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/quantcore/matrix/categorical_matrix.py b/src/quantcore/matrix/categorical_matrix.py index e95f8dd7..1f01c131 100644 --- a/src/quantcore/matrix/categorical_matrix.py +++ b/src/quantcore/matrix/categorical_matrix.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -7,6 +7,7 @@ from .ext.categorical import matvec, sandwich_categorical, transpose_matvec from .ext.split import sandwich_cat_cat, sandwich_cat_dense from .matrix_base import MatrixBase +from .sparse_matrix import SparseMatrix from .util import ( check_matvec_out_shape, check_transpose_matvec_out_shape, @@ -15,6 +16,23 @@ ) +def _is_indexer_full_length(full_length: int, indexer: Any): + if isinstance(indexer, int): + return full_length == 1 + elif isinstance(indexer, list): + if (np.asarray(indexer) > full_length - 1).any(): + raise IndexError("Index out-of-range.") + return len(set(indexer)) == full_length + elif isinstance(indexer, np.ndarray): + if (indexer > full_length - 1).any(): + raise IndexError("Index out-of-range.") + return len(np.unique(indexer)) == full_length + elif isinstance(indexer, slice): + return len(range(*indexer.indices(full_length))) == full_length + else: + raise ValueError(f"Indexing with {type(indexer)} is not allowed.") + + def _none_to_slice(arr: Optional[np.ndarray], n: int) -> Union[slice, np.ndarray]: if arr is None or len(arr) == n: return slice(None, None, None) @@ -262,8 +280,14 @@ def get_col_stds(self, weights: np.ndarray, col_means: np.ndarray) -> np.ndarray def __getitem__(self, item): if isinstance(item, tuple): row, col = item - if not (isinstance(col, slice) and col == slice(None, None, None)): - raise IndexError("Only column indexing is supported.") + if _is_indexer_full_length(self.shape[1], col): + if isinstance(row, int): + row = [row] + return CategoricalMatrix(self.cat[row]) + else: + # return a SparseMatrix if we subset columns + # TODO: this is inefficient. See issue #101. + return SparseMatrix(self.tocsr()[row, col], dtype=self.dtype) else: row = item if isinstance(row, int): diff --git a/tests/test_matrices.py b/tests/test_matrices.py index cadf0407..03ff40a1 100644 --- a/tests/test_matrices.py +++ b/tests/test_matrices.py @@ -579,3 +579,9 @@ def test_pandas_to_matrix(): # was being changed in place. assert df["cl_obj"].dtype == object assert df["ds"].dtype == np.float64 + + +def test_categorical_indexing(): + catvec = [0, 1, 2, 0, 1, 2] + mat = mx.CategoricalMatrix(catvec) + mat[:, [0, 1]] From 024e473ce0596ea8d3b55b06110650cfddd2ef35 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Schmidt Date: Sat, 28 Aug 2021 13:44:39 -0400 Subject: [PATCH 2/2] moved test --- tests/test_categorical_matrix.py | 6 ++++++ tests/test_matrices.py | 6 ------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_categorical_matrix.py b/tests/test_categorical_matrix.py index 876715fe..ec7e9618 100644 --- a/tests/test_categorical_matrix.py +++ b/tests/test_categorical_matrix.py @@ -58,3 +58,9 @@ def test_nulls(mi_element): vec = [0, mi_element, 1] with pytest.raises(ValueError, match="Categorical data can't have missing values"): CategoricalMatrix(vec) + + +def test_categorical_indexing(): + catvec = [0, 1, 2, 0, 1, 2] + mat = CategoricalMatrix(catvec) + mat[:, [0, 1]] diff --git a/tests/test_matrices.py b/tests/test_matrices.py index 03ff40a1..cadf0407 100644 --- a/tests/test_matrices.py +++ b/tests/test_matrices.py @@ -579,9 +579,3 @@ def test_pandas_to_matrix(): # was being changed in place. assert df["cl_obj"].dtype == object assert df["ds"].dtype == np.float64 - - -def test_categorical_indexing(): - catvec = [0, 1, 2, 0, 1, 2] - mat = mx.CategoricalMatrix(catvec) - mat[:, [0, 1]]