Skip to content
34 changes: 24 additions & 10 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import numpy as np

import pandas._libs.lib as lib
from pandas._typing import FrameOrSeriesUnion
from pandas._typing import (
DtypeObj,
FrameOrSeriesUnion,
)
from pandas.util._decorators import Appender

from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -209,8 +212,12 @@ def _validate(data):
# see _libs/lib.pyx for list of inferred types
allowed_types = ["string", "empty", "bytes", "mixed", "mixed-integer"]

values = getattr(data, "values", data) # Series / Index
values = getattr(values, "categories", values) # categorical / normal
# TODO: avoid kludge for tests.extension.test_numpy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Umm, can you avoid this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can revert to using the getattr pattern; either way, it is an anti-pattern.

from pandas.core.internals.managers import _extract_array

data = _extract_array(data)

values = getattr(data, "categories", data) # categorical / normal

inferred_dtype = lib.infer_dtype(values, skipna=True)

Expand Down Expand Up @@ -242,6 +249,7 @@ def _wrap_result(
expand: bool | None = None,
fill_value=np.nan,
returns_string=True,
returns_bool: bool = False,
):
from pandas import (
Index,
Expand Down Expand Up @@ -319,19 +327,25 @@ def cons_row(x):
else:
index = self._orig.index
# This is a mess.
dtype: str | None
if self._is_string and returns_string:
dtype = self._orig.dtype
dtype: DtypeObj | str | None
vdtype = getattr(result, "dtype", None)
if self._is_string:
if is_bool_dtype(vdtype):
dtype = result.dtype
elif returns_string:
dtype = self._orig.dtype
else:
dtype = vdtype
else:
dtype = None
dtype = vdtype

if expand:
cons = self._orig._constructor_expanddim
result = cons(result, columns=name, index=index, dtype=dtype)
else:
# Must be a Series
cons = self._orig._constructor
result = cons(result, name=name, index=index)
result = cons(result, name=name, index=index, dtype=dtype)
result = result.__finalize__(self._orig, method="str")
if name is not None and result.ndim == 1:
# __finalize__ might copy over the original name, but we may
Expand Down Expand Up @@ -369,7 +383,7 @@ def _get_series_list(self, others):
if isinstance(others, ABCSeries):
return [others]
elif isinstance(others, ABCIndex):
return [Series(others._values, index=idx)]
return [Series(others._values, index=idx, dtype=others.dtype)]
elif isinstance(others, ABCDataFrame):
return [others[x] for x in others]
elif isinstance(others, np.ndarray) and others.ndim == 2:
Expand Down Expand Up @@ -547,7 +561,7 @@ def cat(self, others=None, sep=None, na_rep=None, join="left"):
sep = ""

if isinstance(self._orig, ABCIndex):
data = Series(self._orig, index=self._orig)
data = Series(self._orig, index=self._orig, dtype=self._orig.dtype)
else: # Series
data = self._orig

Expand Down