@@ -751,6 +751,138 @@ def _convert_level_number(level_num, columns):
751751 return result
752752
753753
754+ def from_dummies (data , columns = None , prefix_sep = "_" , dtype = "category" , fill_first = None ):
755+ """
756+ The inverse transformation of ``pandas.get_dummies``.
757+
758+ Parameters
759+ ----------
760+ data : DataFrame
761+ columns : list-like, default None
762+ Column names in the DataFrame to be decoded.
763+ If `columns` is None then all the columns will be converted.
764+ prefix_sep : str, default '_'
765+ Separator between original column name and dummy variable
766+ dtype : dtype, default 'category'
767+ Data dtype for new columns - only a single data type is allowed
768+ fill_first : str, list, or dict, default None
769+ Used to fill rows for which all the dummy variables are 0
770+
771+ Returns
772+ -------
773+ transformed : DataFrame
774+
775+ Examples
776+ --------
777+ Say we have a dataframe where some variables have been dummified:
778+
779+ >>> df = pd.DataFrame(
780+ ... {
781+ ... "animal_baboon": [0, 0, 1],
782+ ... "animal_lemur": [0, 1, 0],
783+ ... "animal_zebra": [1, 0, 0],
784+ ... "other_col": ["a", "b", "c"],
785+ ... }
786+ ... )
787+ >>> df
788+ animal_baboon animal_lemur animal_zebra other_col
789+ 0 0 0 1 a
790+ 1 0 1 0 b
791+ 2 1 0 0 c
792+
793+ We can recover the original dataframe using `from_dummies`:
794+
795+ >>> pd.from_dummies(df, columns=['animal'])
796+ other_col animal
797+ 0 a zebra
798+ 1 b lemur
799+ 2 c baboon
800+
801+ Suppose our dataframe has one column from each dummified column
802+ dropped:
803+
804+ >>> df = df.drop('animal_zebra', axis=1)
805+ >>> df
806+ animal_baboon animal_lemur other_col
807+ 0 0 0 a
808+ 1 0 1 b
809+ 2 1 0 c
810+
811+ We can still recover the original dataframe, by using the argument
812+ `fill_first`:
813+
814+ >>> pd.from_dummies(df, columns=["animal"], fill_first=["zebra"])
815+ other_col animal
816+ 0 a zebra
817+ 1 b lemur
818+ 2 c baboon
819+ """
820+ if dtype is None :
821+ dtype = "category"
822+
823+ if columns is None :
824+ data_to_decode = data .copy ()
825+ columns = data .columns .tolist ()
826+ columns = list (
827+ {i .split (prefix_sep )[0 ] for i in data .columns if prefix_sep in i }
828+ )
829+
830+ data_to_decode = data [
831+ [i for i in data .columns for c in columns if i .startswith (c + prefix_sep )]
832+ ]
833+
834+ # Check each row sums to 1 or 0
835+ if not all (i in [0 , 1 ] for i in data_to_decode .sum (axis = 1 ).unique ().tolist ()):
836+ raise ValueError (
837+ "Data cannot be decoded! Each row must contain only 0s and"
838+ " 1s, and each row may have at most one 1"
839+ )
840+
841+ if fill_first is None :
842+ fill_first = [None ] * len (columns )
843+ elif isinstance (fill_first , str ):
844+ fill_first = itertools .cycle ([fill_first ])
845+ elif isinstance (fill_first , dict ):
846+ fill_first = [fill_first [col ] for col in columns ]
847+
848+ out = data .copy ()
849+ for column , fill_first_ in zip (columns , fill_first ):
850+ cols , labels = [
851+ [
852+ i .replace (x , "" )
853+ for i in data_to_decode .columns
854+ if column + prefix_sep in i
855+ ]
856+ for x in ["" , column + prefix_sep ]
857+ ]
858+ if not cols :
859+ continue
860+ out = out .drop (cols , axis = 1 )
861+ if fill_first_ :
862+ cols = [column + prefix_sep + fill_first_ ] + cols
863+ labels = [fill_first_ ] + labels
864+ data [cols [0 ]] = (1 - data [cols [1 :]]).all (axis = 1 )
865+ out [column ] = Series (
866+ np .array (labels )[np .argmax (data [cols ].to_numpy (), axis = 1 )], dtype = dtype
867+ )
868+ return out
869+
870+
871+ def _check_len (item , name , data_to_encode ):
872+ """ Validate prefixes and separator to avoid silently dropping cols. """
873+ len_msg = (
874+ "Length of '{name}' ({len_item}) did not match the "
875+ "length of the columns being encoded ({len_enc})."
876+ )
877+
878+ if is_list_like (item ):
879+ if not len (item ) == data_to_encode .shape [1 ]:
880+ len_msg = len_msg .format (
881+ name = name , len_item = len (item ), len_enc = data_to_encode .shape [1 ]
882+ )
883+ raise ValueError (len_msg )
884+
885+
754886def get_dummies (
755887 data ,
756888 prefix = None ,
@@ -871,20 +1003,8 @@ def get_dummies(
8711003 else :
8721004 data_to_encode = data [columns ]
8731005
874- # validate prefixes and separator to avoid silently dropping cols
875- def check_len (item , name ):
876-
877- if is_list_like (item ):
878- if not len (item ) == data_to_encode .shape [1 ]:
879- len_msg = (
880- f"Length of '{ name } ' ({ len (item )} ) did not match the "
881- "length of the columns being encoded "
882- f"({ data_to_encode .shape [1 ]} )."
883- )
884- raise ValueError (len_msg )
885-
886- check_len (prefix , "prefix" )
887- check_len (prefix_sep , "prefix_sep" )
1006+ _check_len (prefix , "prefix" , data_to_encode )
1007+ _check_len (prefix_sep , "prefix_sep" , data_to_encode )
8881008
8891009 if isinstance (prefix , str ):
8901010 prefix = itertools .cycle ([prefix ])
0 commit comments