Skip to content
29 changes: 21 additions & 8 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,20 +1354,33 @@ def object_dtype(request):

@pytest.fixture(
params=[
"object",
"string[python]",
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")),
]
np.dtype("object"),
("python", pd.NA),
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
],
ids=[
"string=object",
"string=string[python]",
"string=string[pyarrow]",
"string=str[pyarrow]",
],
)
def any_string_dtype(request):
"""
Parametrized fixture for string dtypes.
* 'object'
* 'string[python]'
* 'string[pyarrow]'
* 'string[python]' (NA variant)
* 'string[pyarrow]' (NA variant)
* 'str' (NaN variant, with pyarrow)
"""
return request.param
if isinstance(request.param, np.dtype):
return request.param
else:
# need to instantiate the StringDtype here instead of in the params
# to avoid importing pyarrow during test collection
storage, na_value = request.param
return pd.StringDtype(storage, na_value)


@pytest.fixture(params=tm.DATETIME64_DTYPES)
Expand Down
6 changes: 4 additions & 2 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(
) -> None:
# infer defaults
if storage is None:
if using_string_dtype():
if using_string_dtype() and na_value is not libmissing.NA:
storage = "pyarrow"
else:
storage = get_option("mode.string_storage")
Expand Down Expand Up @@ -167,7 +167,9 @@ def __eq__(self, other: object) -> bool:
return True
try:
other = self.construct_from_string(other)
except TypeError:
except (TypeError, ImportError):
# TypeError if `other` is not a valid string for StringDtype
# ImportError if pyarrow is not installed for "string[pyarrow]"
return False
if isinstance(other, type(self)):
return self.storage == other.storage and self.na_value is other.na_value
Expand Down
1 change: 0 additions & 1 deletion pandas/tests/arrays/categorical/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,6 @@ def test_interval(self):
tm.assert_numpy_array_equal(cat.codes, expected_codes)
tm.assert_index_equal(cat.categories, idx)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_categorical_extension_array_nullable(self, nulls_fixture):
# GH:
arr = pd.arrays.StringArray._from_sequence(
Expand Down
3 changes: 0 additions & 3 deletions pandas/tests/copy_view/test_array.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas import (
DataFrame,
Series,
Expand Down Expand Up @@ -119,7 +117,6 @@ def test_dataframe_array_ea_dtypes():
assert arr.flags.writeable is False


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_dataframe_array_string_dtype():
df = DataFrame({"a": ["a", "b"]}, dtype="string")
arr = np.asarray(df)
Expand Down
2 changes: 0 additions & 2 deletions pandas/tests/copy_view/test_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def test_astype_numpy_to_ea():
assert np.shares_memory(get_array(ser), get_array(result))


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.parametrize(
"dtype, new_dtype", [("object", "string"), ("string", "object")]
)
Expand All @@ -98,7 +97,6 @@ def test_astype_string_and_object(dtype, new_dtype):
tm.assert_frame_equal(df, df_orig)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.parametrize(
"dtype, new_dtype", [("object", "string"), ("string", "object")]
)
Expand Down
3 changes: 0 additions & 3 deletions pandas/tests/dtypes/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

import pandas.util._test_decorators as td

from pandas.core.dtypes.astype import astype_array
Expand Down Expand Up @@ -130,7 +128,6 @@ def test_dtype_equal(name1, dtype1, name2, dtype2):
assert not com.is_dtype_equal(dtype1, dtype2)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.parametrize("name,dtype", list(dtypes.items()), ids=lambda x: str(x))
def test_pyarrow_string_import_error(name, dtype):
# GH-44276
Expand Down
3 changes: 3 additions & 0 deletions pandas/tests/io/parser/test_index_col.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas import (
DataFrame,
Index,
Expand Down Expand Up @@ -343,6 +345,7 @@ def test_infer_types_boolean_sum(all_parsers):
tm.assert_frame_equal(result, expected, check_index_type=False)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.parametrize("dtype, val", [(object, "01"), ("int64", 1)])
def test_specify_dtype_for_index_col(all_parsers, dtype, val, request):
# GH#9435
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/series/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2118,7 +2118,7 @@ def test_series_string_inference_storage_definition(self):
# returning the NA string dtype, so expected is changed from
# "string[pyarrow_numpy]" to "string[pyarrow]"
pytest.importorskip("pyarrow")
expected = Series(["a", "b"], dtype="string[pyarrow]")
expected = Series(["a", "b"], dtype="string[python]")
with pd.option_context("future.infer_string", True):
result = Series(["a", "b"], dtype="string")
tm.assert_series_equal(result, expected)
Expand Down
10 changes: 9 additions & 1 deletion pandas/tests/strings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

import pandas as pd

object_pyarrow_numpy = ("object", "string[pyarrow_numpy]")

def is_object_or_nan_string_dtype(dtype):
"""
Check if string-like dtype is following NaN semantics, i.e. is object
dtype or a NaN-variant of the StringDtype.
"""
return (isinstance(dtype, np.dtype) and dtype == "object") or (
dtype.na_value is np.nan
)


def _convert_na_value(ser, expected):
Expand Down
70 changes: 52 additions & 18 deletions pandas/tests/strings/test_find_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from pandas.tests.strings import (
_convert_na_value,
object_pyarrow_numpy,
is_object_or_nan_string_dtype,
)

# --------------------------------------------------------------------------------------
Expand All @@ -33,7 +33,9 @@ def test_contains(any_string_dtype):
pat = "mmm[_]+"

result = values.str.contains(pat)
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series(
np.array([False, np.nan, True, True, False], dtype=np.object_),
dtype=expected_dtype,
Expand All @@ -52,7 +54,9 @@ def test_contains(any_string_dtype):
dtype=any_string_dtype,
)
result = values.str.contains(pat)
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand All @@ -79,14 +83,18 @@ def test_contains(any_string_dtype):
pat = "mmm[_]+"

result = values.str.contains(pat)
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series(
np.array([False, np.nan, True, True], dtype=np.object_), dtype=expected_dtype
)
tm.assert_series_equal(result, expected)

result = values.str.contains(pat, na=False)
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series(np.array([False, False, True, True]), dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -171,7 +179,9 @@ def test_contains_moar(any_string_dtype):
)

result = s.str.contains("a")
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series(
[False, False, False, True, True, False, np.nan, False, False, True],
dtype=expected_dtype,
Expand Down Expand Up @@ -212,7 +222,9 @@ def test_contains_nan(any_string_dtype):
s = Series([np.nan, np.nan, np.nan], dtype=any_string_dtype)

result = s.str.contains("foo", na=False)
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series([False, False, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand All @@ -230,7 +242,9 @@ def test_contains_nan(any_string_dtype):
tm.assert_series_equal(result, expected)

result = s.str.contains("foo")
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series([np.nan, np.nan, np.nan], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -675,7 +689,9 @@ def test_replace_regex_single_character(regex, any_string_dtype):

def test_match(any_string_dtype):
# New match behavior introduced in 0.13
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)

values = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype)
result = values.str.match(".*(BAD[_]+).*(BAD)")
Expand Down Expand Up @@ -730,20 +746,26 @@ def test_match_na_kwarg(any_string_dtype):
s = Series(["a", "b", np.nan], dtype=any_string_dtype)

result = s.str.match("a", na=False)
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series([True, False, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

result = s.str.match("a")
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series([True, False, np.nan], dtype=expected_dtype)
tm.assert_series_equal(result, expected)


def test_match_case_kwarg(any_string_dtype):
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
result = values.str.match("ab", case=False)
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series([True, True, True, True], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand All @@ -759,7 +781,9 @@ def test_fullmatch(any_string_dtype):
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
)
result = ser.str.fullmatch(".*BAD[_]+.*BAD")
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series([True, False, np.nan, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand All @@ -768,7 +792,9 @@ def test_fullmatch_dollar_literal(any_string_dtype):
# GH 56652
ser = Series(["foo", "foo$foo", np.nan, "foo$"], dtype=any_string_dtype)
result = ser.str.fullmatch("foo\\$")
expected_dtype = "object" if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
"object" if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series([False, False, np.nan, True], dtype=expected_dtype)
tm.assert_series_equal(result, expected)

Expand All @@ -778,14 +804,18 @@ def test_fullmatch_na_kwarg(any_string_dtype):
["fooBAD__barBAD", "BAD_BADleroybrown", np.nan, "foo"], dtype=any_string_dtype
)
result = ser.str.fullmatch(".*BAD[_]+.*BAD", na=False)
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)
expected = Series([True, False, False, False], dtype=expected_dtype)
tm.assert_series_equal(result, expected)


def test_fullmatch_case_kwarg(any_string_dtype, performance_warning):
ser = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
expected_dtype = np.bool_ if any_string_dtype in object_pyarrow_numpy else "boolean"
expected_dtype = (
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
)

expected = Series([True, False, False, False], dtype=expected_dtype)

Expand Down Expand Up @@ -859,7 +889,9 @@ def test_find(any_string_dtype):
ser = Series(
["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXXX"], dtype=any_string_dtype
)
expected_dtype = np.int64 if any_string_dtype in object_pyarrow_numpy else "Int64"
expected_dtype = (
np.int64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64"
)

result = ser.str.find("EF")
expected = Series([4, 3, 1, 0, -1], dtype=expected_dtype)
Expand Down Expand Up @@ -911,7 +943,9 @@ def test_find_nan(any_string_dtype):
ser = Series(
["ABCDEFG", np.nan, "DEFGHIJEF", np.nan, "XXXX"], dtype=any_string_dtype
)
expected_dtype = np.float64 if any_string_dtype in object_pyarrow_numpy else "Int64"
expected_dtype = (
np.float64 if is_object_or_nan_string_dtype(any_string_dtype) else "Int64"
)

result = ser.str.find("EF")
expected = Series([4, np.nan, 1, np.nan, -1], dtype=expected_dtype)
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/strings/test_split_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from pandas.tests.strings import (
_convert_na_value,
object_pyarrow_numpy,
is_object_or_nan_string_dtype,
)


Expand Down Expand Up @@ -385,7 +385,7 @@ def test_split_nan_expand(any_string_dtype):
# check that these are actually np.nan/pd.NA and not None
# TODO see GH 18463
# tm.assert_frame_equal does not differentiate
if any_string_dtype in object_pyarrow_numpy:
if is_object_or_nan_string_dtype(any_string_dtype):
assert all(np.isnan(x) for x in result.iloc[1])
else:
assert all(x is pd.NA for x in result.iloc[1])
Expand Down
Loading