33# don't introduce a pandas/pandas.compat import
44# or we get a bootstrapping problem
55from StringIO import StringIO
6- import os
76
87header = """
98cimport numpy as np
3433ctypedef unsigned char UChar
3534
3635cimport util
37- from util cimport is_array, _checknull, _checknan
36+ from util cimport is_array, _checknull, _checknan, get_nat
37+
38+ cdef int64_t iNaT = get_nat()
3839
3940# import datetime C API
4041PyDateTime_IMPORT
@@ -1150,6 +1151,79 @@ def group_var_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
11501151 (ct * ct - ct))
11511152"""
11521153
1154+ group_count_template = """@cython.boundscheck(False)
1155+ @cython.wraparound(False)
1156+ def group_count_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1157+ ndarray[int64_t] counts,
1158+ ndarray[%(c_type)s, ndim=2] values,
1159+ ndarray[int64_t] labels):
1160+ '''
1161+ Only aggregates on axis=0
1162+ '''
1163+ cdef:
1164+ Py_ssize_t i, j, lab
1165+ Py_ssize_t N = values.shape[0], K = values.shape[1]
1166+ %(c_type)s val
1167+ ndarray[int64_t, ndim=2] nobs = np.zeros((out.shape[0], out.shape[1]),
1168+ dtype=np.int64)
1169+
1170+ if len(values) != len(labels):
1171+ raise AssertionError("len(index) != len(labels)")
1172+
1173+ for i in range(N):
1174+ lab = labels[i]
1175+ if lab < 0:
1176+ continue
1177+
1178+ counts[lab] += 1
1179+ for j in range(K):
1180+ val = values[i, j]
1181+
1182+ # not nan
1183+ nobs[lab, j] += val == val and val != iNaT
1184+
1185+ for i in range(len(counts)):
1186+ for j in range(K):
1187+ out[i, j] = nobs[i, j]
1188+
1189+
1190+ """
1191+
1192+ group_count_bin_template = """@cython.boundscheck(False)
1193+ @cython.wraparound(False)
1194+ def group_count_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1195+ ndarray[int64_t] counts,
1196+ ndarray[%(c_type)s, ndim=2] values,
1197+ ndarray[int64_t] bins):
1198+ '''
1199+ Only aggregates on axis=0
1200+ '''
1201+ cdef:
1202+ Py_ssize_t i, j, ngroups
1203+ Py_ssize_t N = values.shape[0], K = values.shape[1], b = 0
1204+ %(c_type)s val
1205+ ndarray[int64_t, ndim=2] nobs = np.zeros((out.shape[0], out.shape[1]),
1206+ dtype=np.int64)
1207+
1208+ ngroups = len(bins) + (bins[len(bins) - 1] != N)
1209+
1210+ for i in range(N):
1211+ while b < ngroups - 1 and i >= bins[b]:
1212+ b += 1
1213+
1214+ counts[b] += 1
1215+ for j in range(K):
1216+ val = values[i, j]
1217+
1218+ # not nan
1219+ nobs[b, j] += val == val and val != iNaT
1220+
1221+ for i in range(ngroups):
1222+ for j in range(K):
1223+ out[i, j] = nobs[i, j]
1224+
1225+
1226+ """
11531227# add passing bin edges, instead of labels
11541228
11551229
@@ -2145,7 +2219,8 @@ def put2d_%(name)s_%(dest_type)s(ndarray[%(c_type)s, ndim=2, cast=True] values,
21452219#-------------------------------------------------------------------------
21462220# Generators
21472221
2148- def generate_put_template (template , use_ints = True , use_floats = True ):
2222+ def generate_put_template (template , use_ints = True , use_floats = True ,
2223+ use_objects = False ):
21492224 floats_list = [
21502225 ('float64' , 'float64_t' , 'float64_t' , 'np.float64' ),
21512226 ('float32' , 'float32_t' , 'float32_t' , 'np.float32' ),
@@ -2156,11 +2231,14 @@ def generate_put_template(template, use_ints = True, use_floats = True):
21562231 ('int32' , 'int32_t' , 'float64_t' , 'np.float64' ),
21572232 ('int64' , 'int64_t' , 'float64_t' , 'np.float64' ),
21582233 ]
2234+ object_list = [('object' , 'object' , 'float64_t' , 'np.float64' )]
21592235 function_list = []
21602236 if use_floats :
21612237 function_list .extend (floats_list )
21622238 if use_ints :
21632239 function_list .extend (ints_list )
2240+ if use_objects :
2241+ function_list .extend (object_list )
21642242
21652243 output = StringIO ()
21662244 for name , c_type , dest_type , dest_dtype in function_list :
@@ -2251,6 +2329,8 @@ def generate_from_template(template, exclude=None):
22512329 group_max_bin_template ,
22522330 group_ohlc_template ]
22532331
2332+ groupby_count = [group_count_template , group_count_bin_template ]
2333+
22542334templates_1d = [map_indices_template ,
22552335 pad_template ,
22562336 backfill_template ,
@@ -2272,6 +2352,7 @@ def generate_from_template(template, exclude=None):
22722352 take_2d_axis1_template ,
22732353 take_2d_multi_template ]
22742354
2355+
22752356def generate_take_cython_file (path = 'generated.pyx' ):
22762357 with open (path , 'w' ) as f :
22772358 print (header , file = f )
@@ -2288,7 +2369,10 @@ def generate_take_cython_file(path='generated.pyx'):
22882369 print (generate_put_template (template ), file = f )
22892370
22902371 for template in groupbys :
2291- print (generate_put_template (template , use_ints = False ), file = f )
2372+ print (generate_put_template (template , use_ints = False ), file = f )
2373+
2374+ for template in groupby_count :
2375+ print (generate_put_template (template , use_objects = True ), file = f )
22922376
22932377 # for template in templates_1d_datetime:
22942378 # print >> f, generate_from_template_datetime(template)
@@ -2299,5 +2383,6 @@ def generate_take_cython_file(path='generated.pyx'):
22992383 for template in nobool_1d_templates :
23002384 print (generate_from_template (template , exclude = ['bool' ]), file = f )
23012385
2386+
23022387if __name__ == '__main__' :
23032388 generate_take_cython_file ()
0 commit comments