22
33import abc
44from collections import defaultdict
5+ import functools
56from functools import partial
67import inspect
78from typing import (
2930 NDFrameT ,
3031 npt ,
3132)
33+ from pandas .compat ._optional import import_optional_dependency
3234from pandas .errors import SpecificationError
3335from pandas .util ._decorators import cache_readonly
3436from pandas .util ._exceptions import find_stack_level
3537
3638from pandas .core .dtypes .cast import is_nested_object
3739from pandas .core .dtypes .common import (
3840 is_dict_like ,
41+ is_extension_array_dtype ,
3942 is_list_like ,
43+ is_numeric_dtype ,
4044 is_sequence ,
4145)
4246from pandas .core .dtypes .dtypes import (
@@ -121,6 +125,8 @@ def __init__(
121125 result_type : str | None ,
122126 * ,
123127 by_row : Literal [False , "compat" , "_compat" ] = "compat" ,
128+ engine : str = "python" ,
129+ engine_kwargs : dict [str , bool ] | None = None ,
124130 args ,
125131 kwargs ,
126132 ) -> None :
@@ -133,6 +139,9 @@ def __init__(
133139 self .args = args or ()
134140 self .kwargs = kwargs or {}
135141
142+ self .engine = engine
143+ self .engine_kwargs = {} if engine_kwargs is None else engine_kwargs
144+
136145 if result_type not in [None , "reduce" , "broadcast" , "expand" ]:
137146 raise ValueError (
138147 "invalid value for result_type, must be one "
@@ -601,6 +610,13 @@ def apply_list_or_dict_like(self) -> DataFrame | Series:
601610 result: Series, DataFrame, or None
602611 Result when self.func is a list-like or dict-like, None otherwise.
603612 """
613+
614+ if self .engine == "numba" :
615+ raise NotImplementedError (
616+ "The 'numba' engine doesn't support list-like/"
617+ "dict likes of callables yet."
618+ )
619+
604620 if self .axis == 1 and isinstance (self .obj , ABCDataFrame ):
605621 return self .obj .T .apply (self .func , 0 , args = self .args , ** self .kwargs ).T
606622
@@ -768,10 +784,16 @@ def __init__(
768784 ) -> None :
769785 if by_row is not False and by_row != "compat" :
770786 raise ValueError (f"by_row={ by_row } not allowed" )
771- self .engine = engine
772- self .engine_kwargs = engine_kwargs
773787 super ().__init__ (
774- obj , func , raw , result_type , by_row = by_row , args = args , kwargs = kwargs
788+ obj ,
789+ func ,
790+ raw ,
791+ result_type ,
792+ by_row = by_row ,
793+ engine = engine ,
794+ engine_kwargs = engine_kwargs ,
795+ args = args ,
796+ kwargs = kwargs ,
775797 )
776798
777799 # ---------------------------------------------------------------
@@ -792,6 +814,32 @@ def result_columns(self) -> Index:
792814 def series_generator (self ) -> Generator [Series , None , None ]:
793815 pass
794816
817+ @staticmethod
818+ @functools .cache
819+ @abc .abstractmethod
820+ def generate_numba_apply_func (
821+ func , nogil = True , nopython = True , parallel = False
822+ ) -> Callable [[npt .NDArray , Index , Index ], dict [int , Any ]]:
823+ pass
824+
825+ @abc .abstractmethod
826+ def apply_with_numba (self ):
827+ pass
828+
829+ def validate_values_for_numba (self ):
830+ # Validate column dtyps all OK
831+ for colname , dtype in self .obj .dtypes .items ():
832+ if not is_numeric_dtype (dtype ):
833+ raise ValueError (
834+ f"Column { colname } must have a numeric dtype. "
835+ f"Found '{ dtype } ' instead"
836+ )
837+ if is_extension_array_dtype (dtype ):
838+ raise ValueError (
839+ f"Column { colname } is backed by an extension array, "
840+ f"which is not supported by the numba engine."
841+ )
842+
795843 @abc .abstractmethod
796844 def wrap_results_for_axis (
797845 self , results : ResType , res_index : Index
@@ -815,13 +863,12 @@ def values(self):
815863 def apply (self ) -> DataFrame | Series :
816864 """compute the results"""
817865
818- if self .engine == "numba" and not self .raw :
819- raise ValueError (
820- "The numba engine in DataFrame.apply can only be used when raw=True"
821- )
822-
823866 # dispatch to handle list-like or dict-like
824867 if is_list_like (self .func ):
868+ if self .engine == "numba" :
869+ raise NotImplementedError (
870+ "the 'numba' engine doesn't support lists of callables yet"
871+ )
825872 return self .apply_list_or_dict_like ()
826873
827874 # all empty
@@ -830,17 +877,31 @@ def apply(self) -> DataFrame | Series:
830877
831878 # string dispatch
832879 if isinstance (self .func , str ):
880+ if self .engine == "numba" :
881+ raise NotImplementedError (
882+ "the 'numba' engine doesn't support using "
883+ "a string as the callable function"
884+ )
833885 return self .apply_str ()
834886
835887 # ufunc
836888 elif isinstance (self .func , np .ufunc ):
889+ if self .engine == "numba" :
890+ raise NotImplementedError (
891+ "the 'numba' engine doesn't support "
892+ "using a numpy ufunc as the callable function"
893+ )
837894 with np .errstate (all = "ignore" ):
838895 results = self .obj ._mgr .apply ("apply" , func = self .func )
839896 # _constructor will retain self.index and self.columns
840897 return self .obj ._constructor_from_mgr (results , axes = results .axes )
841898
842899 # broadcasting
843900 if self .result_type == "broadcast" :
901+ if self .engine == "numba" :
902+ raise NotImplementedError (
903+ "the 'numba' engine doesn't support result_type='broadcast'"
904+ )
844905 return self .apply_broadcast (self .obj )
845906
846907 # one axis empty
@@ -997,7 +1058,10 @@ def apply_broadcast(self, target: DataFrame) -> DataFrame:
9971058 return result
9981059
9991060 def apply_standard (self ):
1000- results , res_index = self .apply_series_generator ()
1061+ if self .engine == "python" :
1062+ results , res_index = self .apply_series_generator ()
1063+ else :
1064+ results , res_index = self .apply_series_numba ()
10011065
10021066 # wrap results
10031067 return self .wrap_results (results , res_index )
@@ -1021,6 +1085,19 @@ def apply_series_generator(self) -> tuple[ResType, Index]:
10211085
10221086 return results , res_index
10231087
1088+ def apply_series_numba (self ):
1089+ if self .engine_kwargs .get ("parallel" , False ):
1090+ raise NotImplementedError (
1091+ "Parallel apply is not supported when raw=False and engine='numba'"
1092+ )
1093+ if not self .obj .index .is_unique or not self .columns .is_unique :
1094+ raise NotImplementedError (
1095+ "The index/columns must be unique when raw=False and engine='numba'"
1096+ )
1097+ self .validate_values_for_numba ()
1098+ results = self .apply_with_numba ()
1099+ return results , self .result_index
1100+
10241101 def wrap_results (self , results : ResType , res_index : Index ) -> DataFrame | Series :
10251102 from pandas import Series
10261103
@@ -1060,6 +1137,49 @@ class FrameRowApply(FrameApply):
10601137 def series_generator (self ) -> Generator [Series , None , None ]:
10611138 return (self .obj ._ixs (i , axis = 1 ) for i in range (len (self .columns )))
10621139
1140+ @staticmethod
1141+ @functools .cache
1142+ def generate_numba_apply_func (
1143+ func , nogil = True , nopython = True , parallel = False
1144+ ) -> Callable [[npt .NDArray , Index , Index ], dict [int , Any ]]:
1145+ numba = import_optional_dependency ("numba" )
1146+ from pandas import Series
1147+
1148+ # Import helper from extensions to cast string object -> np strings
1149+ # Note: This also has the side effect of loading our numba extensions
1150+ from pandas .core ._numba .extensions import maybe_cast_str
1151+
1152+ jitted_udf = numba .extending .register_jitable (func )
1153+
1154+ # Currently the parallel argument doesn't get passed through here
1155+ # (it's disabled) since the dicts in numba aren't thread-safe.
1156+ @numba .jit (nogil = nogil , nopython = nopython , parallel = parallel )
1157+ def numba_func (values , col_names , df_index ):
1158+ results = {}
1159+ for j in range (values .shape [1 ]):
1160+ # Create the series
1161+ ser = Series (
1162+ values [:, j ], index = df_index , name = maybe_cast_str (col_names [j ])
1163+ )
1164+ results [j ] = jitted_udf (ser )
1165+ return results
1166+
1167+ return numba_func
1168+
1169+ def apply_with_numba (self ) -> dict [int , Any ]:
1170+ nb_func = self .generate_numba_apply_func (
1171+ cast (Callable , self .func ), ** self .engine_kwargs
1172+ )
1173+ from pandas .core ._numba .extensions import set_numba_data
1174+
1175+ # Convert from numba dict to regular dict
1176+ # Our isinstance checks in the df constructor don't pass for numbas typed dict
1177+ with set_numba_data (self .obj .index ) as index , set_numba_data (
1178+ self .columns
1179+ ) as columns :
1180+ res = dict (nb_func (self .values , columns , index ))
1181+ return res
1182+
10631183 @property
10641184 def result_index (self ) -> Index :
10651185 return self .columns
@@ -1143,6 +1263,52 @@ def series_generator(self) -> Generator[Series, None, None]:
11431263 object .__setattr__ (ser , "_name" , name )
11441264 yield ser
11451265
1266+ @staticmethod
1267+ @functools .cache
1268+ def generate_numba_apply_func (
1269+ func , nogil = True , nopython = True , parallel = False
1270+ ) -> Callable [[npt .NDArray , Index , Index ], dict [int , Any ]]:
1271+ numba = import_optional_dependency ("numba" )
1272+ from pandas import Series
1273+ from pandas .core ._numba .extensions import maybe_cast_str
1274+
1275+ jitted_udf = numba .extending .register_jitable (func )
1276+
1277+ @numba .jit (nogil = nogil , nopython = nopython , parallel = parallel )
1278+ def numba_func (values , col_names_index , index ):
1279+ results = {}
1280+ # Currently the parallel argument doesn't get passed through here
1281+ # (it's disabled) since the dicts in numba aren't thread-safe.
1282+ for i in range (values .shape [0 ]):
1283+ # Create the series
1284+ # TODO: values corrupted without the copy
1285+ ser = Series (
1286+ values [i ].copy (),
1287+ index = col_names_index ,
1288+ name = maybe_cast_str (index [i ]),
1289+ )
1290+ results [i ] = jitted_udf (ser )
1291+
1292+ return results
1293+
1294+ return numba_func
1295+
1296+ def apply_with_numba (self ) -> dict [int , Any ]:
1297+ nb_func = self .generate_numba_apply_func (
1298+ cast (Callable , self .func ), ** self .engine_kwargs
1299+ )
1300+
1301+ from pandas .core ._numba .extensions import set_numba_data
1302+
1303+ # Convert from numba dict to regular dict
1304+ # Our isinstance checks in the df constructor don't pass for numbas typed dict
1305+ with set_numba_data (self .obj .index ) as index , set_numba_data (
1306+ self .columns
1307+ ) as columns :
1308+ res = dict (nb_func (self .values , columns , index ))
1309+
1310+ return res
1311+
11461312 @property
11471313 def result_index (self ) -> Index :
11481314 return self .index
0 commit comments