diff --git a/src/tabmat/categorical_matrix.py b/src/tabmat/categorical_matrix.py index 60bc53d5..e8200ab8 100644 --- a/src/tabmat/categorical_matrix.py +++ b/src/tabmat/categorical_matrix.py @@ -244,8 +244,8 @@ def _extract_codes_and_categories(cat_vec): elif _is_polars(cat_vec): if not _is_polars(cat_vec.dtype): cat_vec = cat_vec.cast(pl.Categorical) - categories = cat_vec.cat.get_categories().to_numpy() - indices = cat_vec.to_physical().fill_null(-1).to_numpy() + categories = cat_vec.cat.to_local().cat.get_categories().to_numpy() + indices = cat_vec.cat.to_local().to_physical().fill_null(-1).to_numpy() else: indices, categories = pd.factorize(cat_vec, sort=True) return indices, categories diff --git a/tests/test_categorical_matrix.py b/tests/test_categorical_matrix.py index 5d74d676..1c102a61 100644 --- a/tests/test_categorical_matrix.py +++ b/tests/test_categorical_matrix.py @@ -2,9 +2,10 @@ import numpy as np import pandas as pd +import polars as pl import pytest -from tabmat.categorical_matrix import CategoricalMatrix +from tabmat.categorical_matrix import CategoricalMatrix, _extract_codes_and_categories @pytest.fixture @@ -202,3 +203,13 @@ def test_categorical_indexing(drop_first, missing, cat_missing_method): dummy_na=cat_missing_method == "convert" and missing, ).to_numpy()[:, [0, 1]] np.testing.assert_allclose(mat[:, [0, 1]].toarray(), expected) + + +def test_polars_non_contiguous_codes(): + str_series = ["labrador", "boxer", "beagle"] + with pl.StringCache(): + _ = pl.Series(["beagle", "poodle", "labrador"], dtype=pl.Categorical) + cat_series = pl.Series(str_series, dtype=pl.Categorical) + + indices, categories = _extract_codes_and_categories(cat_series) + np.testing.assert_array_equal(str_series, categories[indices].tolist())