@@ -3086,52 +3086,38 @@ def _get_group_names(regex: Pattern) -> List[Hashable]:
30863086
30873087def _str_extract_noexpand (arr , pat , flags = 0 ):
30883088 """
3089- Find groups in each string in the Series using passed regular
3090- expression. This function is called from
3091- str_extract(expand=False), and can return Series, DataFrame, or
3092- Index.
3089+ Find groups in each string in the Series/Index using passed regular expression.
30933090
3091+ This function is called from str_extract(expand=False) when there is a single group
3092+ in the regex.
3093+
3094+ Returns
3095+ -------
3096+ np.ndarray
30943097 """
3095- from pandas import (
3096- DataFrame ,
3097- array as pd_array ,
3098- )
3098+ from pandas import array as pd_array
30993099
31003100 regex = re .compile (pat , flags = flags )
31013101 groups_or_na = _groups_or_na_fun (regex )
31023102 result_dtype = _result_dtype (arr )
31033103
3104- if regex .groups == 1 :
3105- result = np .array ([groups_or_na (val )[0 ] for val in arr ], dtype = object )
3106- name = _get_single_group_name (regex )
3107- # not dispatching, so we have to reconstruct here.
3108- result = pd_array (result , dtype = result_dtype )
3109- else :
3110- name = None
3111- columns = _get_group_names (regex )
3112- if arr .size == 0 :
3113- # error: Incompatible types in assignment (expression has type
3114- # "DataFrame", variable has type "ndarray")
3115- result = DataFrame ( # type: ignore[assignment]
3116- columns = columns , dtype = result_dtype
3117- )
3118- else :
3119- # error: Incompatible types in assignment (expression has type
3120- # "DataFrame", variable has type "ndarray")
3121- result = DataFrame ( # type:ignore[assignment]
3122- [groups_or_na (val ) for val in arr ],
3123- columns = columns ,
3124- index = arr .index ,
3125- dtype = result_dtype ,
3126- )
3127- return result , name
3104+ result = np .array ([groups_or_na (val )[0 ] for val in arr ], dtype = object )
3105+ # not dispatching, so we have to reconstruct here.
3106+ result = pd_array (result , dtype = result_dtype )
3107+ return result
31283108
31293109
31303110def _str_extract_frame (arr , pat , flags = 0 ):
31313111 """
3132- For each subject string in the Series, extract groups from the
3133- first match of regular expression pat. This function is called from
3134- str_extract(expand=True), and always returns a DataFrame.
3112+ Find groups in each string in the Series/Index using passed regular expression.
3113+
3114+ For each subject string in the Series/Index, extract groups from the first match of
3115+ regular expression pat. This function is called from str_extract(expand=True) or
3116+ str_extract(expand=False) when there is more than one group in the regex.
3117+
3118+ Returns
3119+ -------
3120+ DataFrame
31353121
31363122 """
31373123 from pandas import DataFrame
@@ -3141,11 +3127,13 @@ def _str_extract_frame(arr, pat, flags=0):
31413127 columns = _get_group_names (regex )
31423128 result_dtype = _result_dtype (arr )
31433129
3144- if len ( arr ) == 0 :
3130+ if arr . size == 0 :
31453131 return DataFrame (columns = columns , dtype = result_dtype )
3146- try :
3132+
3133+ result_index : Optional ["Index" ]
3134+ if isinstance (arr , ABCSeries ):
31473135 result_index = arr .index
3148- except AttributeError :
3136+ else :
31493137 result_index = None
31503138 return DataFrame (
31513139 [groups_or_na (val ) for val in arr ],
@@ -3156,12 +3144,16 @@ def _str_extract_frame(arr, pat, flags=0):
31563144
31573145
31583146def str_extract (arr , pat , flags = 0 , expand = True ):
3159- if expand :
3147+ regex = re .compile (pat , flags = flags )
3148+ returns_df = regex .groups > 1 or expand
3149+
3150+ if returns_df :
3151+ name = None
31603152 result = _str_extract_frame (arr ._orig , pat , flags = flags )
3161- return result .__finalize__ (arr ._orig , method = "str_extract" )
31623153 else :
3163- result , name = _str_extract_noexpand (arr ._orig , pat , flags = flags )
3164- return arr ._wrap_result (result , name = name , expand = expand )
3154+ name = _get_single_group_name (regex )
3155+ result = _str_extract_noexpand (arr ._orig , pat , flags = flags )
3156+ return arr ._wrap_result (result , name = name )
31653157
31663158
31673159def str_extractall (arr , pat , flags = 0 ):
0 commit comments