From f86b9aab3421976fa311dccd6772a99f2fd0d2a5 Mon Sep 17 00:00:00 2001 From: lbittarello Date: Wed, 11 Sep 2024 17:00:35 +0100 Subject: [PATCH 1/3] Reindex categorical codes from Polars --- src/tabmat/categorical_matrix.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/tabmat/categorical_matrix.py b/src/tabmat/categorical_matrix.py index 60bc53d5..539479a6 100644 --- a/src/tabmat/categorical_matrix.py +++ b/src/tabmat/categorical_matrix.py @@ -245,7 +245,10 @@ def _extract_codes_and_categories(cat_vec): if not _is_polars(cat_vec.dtype): cat_vec = cat_vec.cast(pl.Categorical) categories = cat_vec.cat.get_categories().to_numpy() - indices = cat_vec.to_physical().fill_null(-1).to_numpy() + codes = cat_vec.to_physical().to_numpy() + # Remap the indices in case they don't start from zero or contain gaps + _, indices = np.unique(codes, return_inverse=True) + indices = np.where(cat_vec.is_null(), -1, indices) else: indices, categories = pd.factorize(cat_vec, sort=True) return indices, categories From db89cddee02ad82f5a4f4d348c495cee5b6d531e Mon Sep 17 00:00:00 2001 From: lbittarello Date: Thu, 12 Sep 2024 09:27:30 +0100 Subject: [PATCH 2/3] Localize --- src/tabmat/categorical_matrix.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/tabmat/categorical_matrix.py b/src/tabmat/categorical_matrix.py index 539479a6..e8200ab8 100644 --- a/src/tabmat/categorical_matrix.py +++ b/src/tabmat/categorical_matrix.py @@ -244,11 +244,8 @@ def _extract_codes_and_categories(cat_vec): elif _is_polars(cat_vec): if not _is_polars(cat_vec.dtype): cat_vec = cat_vec.cast(pl.Categorical) - categories = cat_vec.cat.get_categories().to_numpy() - codes = cat_vec.to_physical().to_numpy() - # Remap the indices in case they don't start from zero or contain gaps - _, indices = np.unique(codes, return_inverse=True) - indices = np.where(cat_vec.is_null(), -1, indices) + categories = cat_vec.cat.to_local().cat.get_categories().to_numpy() + indices = cat_vec.cat.to_local().to_physical().fill_null(-1).to_numpy() else: indices, categories = pd.factorize(cat_vec, sort=True) return indices, categories From 9a56ea8565dae84f05059e01a1cdee24cd16a252 Mon Sep 17 00:00:00 2001 From: Martin Stancsics Date: Thu, 12 Sep 2024 11:21:07 +0200 Subject: [PATCH 3/3] Add test for non-contiguous polars categoricals --- tests/test_categorical_matrix.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/test_categorical_matrix.py b/tests/test_categorical_matrix.py index 5d74d676..1c102a61 100644 --- a/tests/test_categorical_matrix.py +++ b/tests/test_categorical_matrix.py @@ -2,9 +2,10 @@ import numpy as np import pandas as pd +import polars as pl import pytest -from tabmat.categorical_matrix import CategoricalMatrix +from tabmat.categorical_matrix import CategoricalMatrix, _extract_codes_and_categories @pytest.fixture @@ -202,3 +203,13 @@ def test_categorical_indexing(drop_first, missing, cat_missing_method): dummy_na=cat_missing_method == "convert" and missing, ).to_numpy()[:, [0, 1]] np.testing.assert_allclose(mat[:, [0, 1]].toarray(), expected) + + +def test_polars_non_contiguous_codes(): + str_series = ["labrador", "boxer", "beagle"] + with pl.StringCache(): + _ = pl.Series(["beagle", "poodle", "labrador"], dtype=pl.Categorical) + cat_series = pl.Series(str_series, dtype=pl.Categorical) + + indices, categories = _extract_codes_and_categories(cat_series) + np.testing.assert_array_equal(str_series, categories[indices].tolist())