8080)
8181
8282from pandas import (
83+ ArrowDtype ,
8384 Categorical ,
8485 Index ,
8586 MultiIndex ,
8687 Series ,
8788)
8889import pandas .core .algorithms as algos
8990from pandas .core .arrays import (
91+ ArrowExtensionArray ,
9092 BaseMaskedArray ,
9193 ExtensionArray ,
9294)
@@ -2377,7 +2379,11 @@ def _factorize_keys(
23772379 rk = ensure_int64 (rk .codes )
23782380
23792381 elif isinstance (lk , ExtensionArray ) and is_dtype_equal (lk .dtype , rk .dtype ):
2380- if not isinstance (lk , BaseMaskedArray ):
2382+ if not isinstance (lk , BaseMaskedArray ) and not (
2383+ # exclude arrow dtypes that would get cast to object
2384+ isinstance (lk .dtype , ArrowDtype )
2385+ and is_numeric_dtype (lk .dtype .numpy_dtype )
2386+ ):
23812387 lk , _ = lk ._values_for_factorize ()
23822388
23832389 # error: Item "ndarray" of "Union[Any, ndarray]" has no attribute
@@ -2392,6 +2398,16 @@ def _factorize_keys(
23922398 assert isinstance (rk , BaseMaskedArray )
23932399 llab = rizer .factorize (lk ._data , mask = lk ._mask )
23942400 rlab = rizer .factorize (rk ._data , mask = rk ._mask )
2401+ elif isinstance (lk , ArrowExtensionArray ):
2402+ assert isinstance (rk , ArrowExtensionArray )
2403+ # we can only get here with numeric dtypes
2404+ # TODO: Remove when we have a Factorizer for Arrow
2405+ llab = rizer .factorize (
2406+ lk .to_numpy (na_value = 1 , dtype = lk .dtype .numpy_dtype ), mask = lk .isna ()
2407+ )
2408+ rlab = rizer .factorize (
2409+ rk .to_numpy (na_value = 1 , dtype = lk .dtype .numpy_dtype ), mask = rk .isna ()
2410+ )
23952411 else :
23962412 # Argument 1 to "factorize" of "ObjectFactorizer" has incompatible type
23972413 # "Union[ndarray[Any, dtype[signedinteger[_64Bit]]],
@@ -2450,6 +2466,8 @@ def _convert_arrays_and_get_rizer_klass(
24502466 # Invalid index type "type" for "Dict[Type[object], Type[Factorizer]]";
24512467 # expected type "Type[object]"
24522468 klass = _factorizers [lk .dtype .type ] # type: ignore[index]
2469+ elif isinstance (lk .dtype , ArrowDtype ):
2470+ klass = _factorizers [lk .dtype .numpy_dtype .type ]
24532471 else :
24542472 klass = _factorizers [lk .dtype .type ]
24552473
0 commit comments