3333ctypedef unsigned char UChar
3434
3535cimport util
36- from util cimport is_array, _checknull, _checknan
36+ from util cimport is_array, _checknull, _checknan, get_nat
37+
38+ cdef int64_t iNaT = get_nat()
3739
3840# import datetime C API
3941PyDateTime_IMPORT
@@ -1159,16 +1161,15 @@ def group_count_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
11591161 Only aggregates on axis=0
11601162 '''
11611163 cdef:
1162- Py_ssize_t i, j, N, K, lab
1163- %(dest_type2)s val
1164- ndarray[%(dest_type2)s, ndim=2] nobs = np.zeros_like(out)
1165-
1164+ Py_ssize_t i, j, lab
1165+ Py_ssize_t N = values.shape[0], K = values.shape[1]
1166+ %(c_type)s val
1167+ ndarray[int64_t, ndim=2] nobs = np.zeros((out.shape[0], out.shape[1]),
1168+ dtype=np.int64)
11661169
1167- if not len(values) = = len(labels):
1170+ if len(values) ! = len(labels):
11681171 raise AssertionError("len(index) != len(labels)")
11691172
1170- N, K = (<object> values).shape
1171-
11721173 for i in range(N):
11731174 lab = labels[i]
11741175 if lab < 0:
@@ -1179,7 +1180,7 @@ def group_count_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
11791180 val = values[i, j]
11801181
11811182 # not nan
1182- nobs[lab, j] += val == val
1183+ nobs[lab, j] += val == val and val != iNaT
11831184
11841185 for i in range(len(counts)):
11851186 for j in range(K):
@@ -1198,20 +1199,14 @@ def group_count_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
11981199 Only aggregates on axis=0
11991200 '''
12001201 cdef:
1201- Py_ssize_t i, j, N, K, ngroups, b
1202- %(dest_type2)s val, count
1203- ndarray[%(dest_type2)s, ndim=2] nobs
1204-
1205- nobs = np.zeros_like(out )
1202+ Py_ssize_t i, j, ngroups
1203+ Py_ssize_t N = values.shape[0], K = values.shape[1], b = 0
1204+ %(c_type)s val
1205+ ndarray[int64_t, ndim=2] nobs = np.zeros((out.shape[0], out.shape[1]),
1206+ dtype= np.int64 )
12061207
1207- if bins[len(bins) - 1] == len(values):
1208- ngroups = len(bins)
1209- else:
1210- ngroups = len(bins) + 1
1208+ ngroups = len(bins) + (bins[len(bins) - 1] != N)
12111209
1212- N, K = (<object> values).shape
1213-
1214- b = 0
12151210 for i in range(N):
12161211 while b < ngroups - 1 and i >= bins[b]:
12171212 b += 1
@@ -1221,7 +1216,7 @@ def group_count_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
12211216 val = values[i, j]
12221217
12231218 # not nan
1224- nobs[b, j] += val == val
1219+ nobs[b, j] += val == val and val != iNaT
12251220
12261221 for i in range(ngroups):
12271222 for j in range(K):
@@ -2224,7 +2219,8 @@ def put2d_%(name)s_%(dest_type)s(ndarray[%(c_type)s, ndim=2, cast=True] values,
22242219#-------------------------------------------------------------------------
22252220# Generators
22262221
2227- def generate_put_template (template , use_ints = True , use_floats = True ):
2222+ def generate_put_template (template , use_ints = True , use_floats = True ,
2223+ use_objects = False ):
22282224 floats_list = [
22292225 ('float64' , 'float64_t' , 'float64_t' , 'np.float64' ),
22302226 ('float32' , 'float32_t' , 'float32_t' , 'np.float32' ),
@@ -2235,11 +2231,14 @@ def generate_put_template(template, use_ints = True, use_floats = True):
22352231 ('int32' , 'int32_t' , 'float64_t' , 'np.float64' ),
22362232 ('int64' , 'int64_t' , 'float64_t' , 'np.float64' ),
22372233 ]
2234+ object_list = [('object' , 'object' , 'float64_t' , 'np.float64' )]
22382235 function_list = []
22392236 if use_floats :
22402237 function_list .extend (floats_list )
22412238 if use_ints :
22422239 function_list .extend (ints_list )
2240+ if use_objects :
2241+ function_list .extend (object_list )
22432242
22442243 output = StringIO ()
22452244 for name , c_type , dest_type , dest_dtype in function_list :
@@ -2373,7 +2372,7 @@ def generate_take_cython_file(path='generated.pyx'):
23732372 print (generate_put_template (template , use_ints = False ), file = f )
23742373
23752374 for template in groupby_count :
2376- print (generate_put_template (template ), file = f )
2375+ print (generate_put_template (template , use_objects = True ), file = f )
23772376
23782377 # for template in templates_1d_datetime:
23792378 # print >> f, generate_from_template_datetime(template)
0 commit comments