@@ -421,17 +421,21 @@ def str_extract(arr, pat, flags=0):
421421 Pattern or regular expression
422422 flags : int, default 0 (no flags)
423423 re module flags, e.g. re.IGNORECASE
424+ expand : None or bool, default None
425+ * If None, return Series/Index (one group) or DataFrame/MultiIndex (multiple groups)
426+ * If True, return DataFrame/MultiIndex expanding dimensionality.
427+ * If False, return Series/Index.
424428
425429 Returns
426430 -------
427- extracted groups : Series (one group) or DataFrame (multiple groups)
431+ extracted groups : Series/Index or DataFrame/MultiIndex of objects
428432 Note that dtype of the result is always object, even when no match is
429433 found and the result is a Series or DataFrame containing only NaN
430434 values.
431435
432436 Examples
433437 --------
434- A pattern with one group will return a Series. Non-matches will be NaN.
438+ A pattern with one group returns a Series. Non-matches will be NaN.
435439
436440 >>> Series(['a1', 'b2', 'c3']).str.extract('[ab](\d)')
437441 0 1
@@ -463,11 +467,14 @@ def str_extract(arr, pat, flags=0):
463467 1 b 2
464468 2 NaN NaN
465469
466- """
467- from pandas .core .series import Series
468- from pandas .core .frame import DataFrame
469- from pandas .core .index import Index
470+ Or you can specify ``expand=False`` to return Series.
470471
472+ >>> pd.Series(['a1', 'b2', 'c3']).str.extract('([ab])?(\d)', expand=False)
473+ 0 [a, 1]
474+ 1 [b, 2]
475+ 2 [nan, 3]
476+ Name: [0, 1], dtype: object
477+ """
471478 regex = re .compile (pat , flags = flags )
472479 # just to be safe, check this
473480 if regex .groups == 0 :
@@ -487,18 +494,9 @@ def f(x):
487494 result = np .array ([f (val )[0 ] for val in arr ], dtype = object )
488495 name = _get_single_group_name (regex )
489496 else :
490- if isinstance (arr , Index ):
491- raise ValueError ("only one regex group is supported with Index" )
492- name = None
493497 names = dict (zip (regex .groupindex .values (), regex .groupindex .keys ()))
494- columns = [names .get (1 + i , i ) for i in range (regex .groups )]
495- if arr .empty :
496- result = DataFrame (columns = columns , dtype = object )
497- else :
498- result = DataFrame ([f (val ) for val in arr ],
499- columns = columns ,
500- index = arr .index ,
501- dtype = object )
498+ name = [names .get (1 + i , i ) for i in range (regex .groups )]
499+ result = np .array ([f (val ) for val in arr ], dtype = object )
502500 return result , name
503501
504502
@@ -511,10 +509,13 @@ def str_get_dummies(arr, sep='|'):
511509 ----------
512510 sep : string, default "|"
513511 String to split on.
512+ expand : bool, default True
513+ * If True, return DataFrame/MultiIndex expanding dimensionality.
514+ * If False, return Series/Index.
514515
515516 Returns
516517 -------
517- dummies : DataFrame
518+ dummies : Series/Index or DataFrame/MultiIndex of objects
518519
519520 Examples
520521 --------
@@ -534,15 +535,15 @@ def str_get_dummies(arr, sep='|'):
534535 --------
535536 pandas.get_dummies
536537 """
537- from pandas .core .frame import DataFrame
538538 from pandas .core .index import Index
539-
540- # GH9980, Index.str does not support get_dummies() as it returns a frame
539+ # TODO: Add fillna GH 10089
541540 if isinstance (arr , Index ):
542- raise TypeError ("get_dummies is not supported for string methods on Index" )
543-
544- # TODO remove this hack?
545- arr = arr .fillna ('' )
541+ # temp hack
542+ values = arr .values
543+ values [isnull (values )] = ''
544+ arr = Index (values )
545+ else :
546+ arr = arr .fillna ('' )
546547 try :
547548 arr = sep + arr + sep
548549 except TypeError :
@@ -558,7 +559,7 @@ def str_get_dummies(arr, sep='|'):
558559 for i , t in enumerate (tags ):
559560 pat = sep + t + sep
560561 dummies [:, i ] = lib .map_infer (arr .values , lambda x : pat in x )
561- return DataFrame ( dummies , arr . index , tags )
562+ return dummies , tags
562563
563564
564565def str_join (arr , sep ):
@@ -1043,40 +1044,19 @@ def __iter__(self):
10431044 i += 1
10441045 g = self .get (i )
10451046
1046- def _wrap_result (self , result , ** kwargs ):
1047-
1048- # leave as it is to keep extract and get_dummies results
1049- # can be merged to _wrap_result_expand in v0.17
1050- from pandas .core .series import Series
1051- from pandas .core .frame import DataFrame
1052- from pandas .core .index import Index
1053-
1054- if not hasattr (result , 'ndim' ):
1055- return result
1056- name = kwargs .get ('name' ) or getattr (result , 'name' , None ) or self .series .name
1057-
1058- if result .ndim == 1 :
1059- if isinstance (self .series , Index ):
1060- # if result is a boolean np.array, return the np.array
1061- # instead of wrapping it into a boolean Index (GH 8875)
1062- if is_bool_dtype (result ):
1063- return result
1064- return Index (result , name = name )
1065- return Series (result , index = self .series .index , name = name )
1066- else :
1067- assert result .ndim < 3
1068- return DataFrame (result , index = self .series .index )
1047+ def _wrap_result (self , result , expand = False , name = None ):
1048+ from pandas .core .index import Index , MultiIndex
10691049
1070- def _wrap_result_expand (self , result , expand = False ):
10711050 if not isinstance (expand , bool ):
10721051 raise ValueError ("expand must be True or False" )
10731052
1074- from pandas .core .index import Index , MultiIndex
1053+ if name is None :
1054+ name = getattr (result , 'name' , None ) or self .series .name
1055+
10751056 if not hasattr (result , 'ndim' ):
10761057 return result
10771058
10781059 if isinstance (self .series , Index ):
1079- name = getattr (result , 'name' , None )
10801060 # if result is a boolean np.array, return the np.array
10811061 # instead of wrapping it into a boolean Index (GH 8875)
10821062 if hasattr (result , 'dtype' ) and is_bool_dtype (result ):
@@ -1092,10 +1072,12 @@ def _wrap_result_expand(self, result, expand=False):
10921072 if expand :
10931073 cons_row = self .series ._constructor
10941074 cons = self .series ._constructor_expanddim
1095- data = [cons_row (x ) for x in result ]
1096- return cons (data , index = index )
1075+ data = [cons_row (x , index = name ) for x in result ]
1076+ return cons (data , index = index , columns = name ,
1077+ dtype = result .dtype )
10971078 else :
1098- name = getattr (result , 'name' , None )
1079+ if result .ndim > 1 :
1080+ result = list (result )
10991081 cons = self .series ._constructor
11001082 return cons (result , name = name , index = index )
11011083
@@ -1109,7 +1091,7 @@ def cat(self, others=None, sep=None, na_rep=None):
11091091 @copy (str_split )
11101092 def split (self , pat = None , n = - 1 , expand = False ):
11111093 result = str_split (self .series , pat , n = n )
1112- return self ._wrap_result_expand (result , expand = expand )
1094+ return self ._wrap_result (result , expand = expand )
11131095
11141096 _shared_docs ['str_partition' ] = ("""
11151097 Split the string at the %(side)s occurrence of `sep`, and return 3 elements
@@ -1160,15 +1142,15 @@ def split(self, pat=None, n=-1, expand=False):
11601142 def partition (self , pat = ' ' , expand = True ):
11611143 f = lambda x : x .partition (pat )
11621144 result = _na_map (f , self .series )
1163- return self ._wrap_result_expand (result , expand = expand )
1145+ return self ._wrap_result (result , expand = expand )
11641146
11651147 @Appender (_shared_docs ['str_partition' ] % {'side' : 'last' ,
11661148 'return' : '3 elements containing two empty strings, followed by the string itself' ,
11671149 'also' : 'partition : Split the string at the first occurrence of `sep`' })
11681150 def rpartition (self , pat = ' ' , expand = True ):
11691151 f = lambda x : x .rpartition (pat )
11701152 result = _na_map (f , self .series )
1171- return self ._wrap_result_expand (result , expand = expand )
1153+ return self ._wrap_result (result , expand = expand )
11721154
11731155 @copy (str_get )
11741156 def get (self , i ):
@@ -1309,9 +1291,9 @@ def wrap(self, width, **kwargs):
13091291 return self ._wrap_result (result )
13101292
13111293 @copy (str_get_dummies )
1312- def get_dummies (self , sep = '|' ):
1313- result = str_get_dummies (self .series , sep )
1314- return self ._wrap_result (result )
1294+ def get_dummies (self , sep = '|' , expand = True ):
1295+ result , name = str_get_dummies (self .series , sep )
1296+ return self ._wrap_result (result , name = name , expand = expand )
13151297
13161298 @copy (str_translate )
13171299 def translate (self , table , deletechars = None ):
@@ -1324,9 +1306,18 @@ def translate(self, table, deletechars=None):
13241306 findall = _pat_wrapper (str_findall , flags = True )
13251307
13261308 @copy (str_extract )
1327- def extract (self , pat , flags = 0 ):
1309+ def extract (self , pat , flags = 0 , expand = None ):
13281310 result , name = str_extract (self .series , pat , flags = flags )
1329- return self ._wrap_result (result , name = name )
1311+ if expand is None and hasattr (result , 'ndim' ):
1312+ # to be compat with previous behavior
1313+ if len (result ) == 0 :
1314+ # for empty input
1315+ expand = True if isinstance (name , list ) else False
1316+ elif result .ndim > 1 :
1317+ expand = True
1318+ else :
1319+ expand = False
1320+ return self ._wrap_result (result , name = name , expand = expand )
13301321
13311322 _shared_docs ['find' ] = ("""
13321323 Return %(side)s indexes in each strings in the Series/Index
0 commit comments