@@ -20,6 +20,18 @@ ctypedef fused rank_t:
2020 object
2121
2222
23+ cdef inline bint _treat_as_na(rank_t val, bint is_datetimelike) nogil:
24+ if rank_t is object:
25+ # Never expected to run; this branch only keeps the object specialization
26+ # away from the `val != val` below, where cython raises about gil acquisition.
27+ raise NotImplementedError
28+
29+ elif rank_t is int64_t:
30+ return is_datetimelike and val == NPY_NAT
31+ else:
32+ return val != val
33+
34+
2335@cython.wraparound(False)
2436@cython.boundscheck(False)
2537def group_last(rank_t[:, :] out,
@@ -61,24 +73,16 @@ def group_last(rank_t[:, :] out,
6173 for j in range(K):
6274 val = values[i, j]
6375
64- # not nan
65- if rank_t is int64_t:
66- # need a special notna check
67- if val != NPY_NAT:
68- nobs[lab, j] += 1
69- resx[lab, j] = val
70- else:
71- if val == val:
72- nobs[lab, j] += 1
73- resx[lab, j] = val
76+ if val == val:
77+ # NB: use _treat_as_na here once
78+ # conditional-nogil is available.
79+ nobs[lab, j] += 1
80+ resx[lab, j] = val
7481
7582 for i in range(ncounts):
7683 for j in range(K):
7784 if nobs[i, j] == 0:
78- if rank_t is int64_t:
79- out[i, j] = NPY_NAT
80- else:
81- out[i, j] = NAN
85+ out[i, j] = NAN
8286 else:
8387 out[i, j] = resx[i, j]
8488 else:
@@ -92,16 +96,10 @@ def group_last(rank_t[:, :] out,
9296 for j in range(K):
9397 val = values[i, j]
9498
95- # not nan
96- if rank_t is int64_t:
97- # need a special notna check
98- if val != NPY_NAT:
99- nobs[lab, j] += 1
100- resx[lab, j] = val
101- else:
102- if val == val:
103- nobs[lab, j] += 1
104- resx[lab, j] = val
99+ if not _treat_as_na(val, True):
100+ # TODO: is is_datetimelike=True always right? It treats int64 NPY_NAT as NA.
101+ nobs[lab, j] += 1
102+ resx[lab, j] = val
105103
106104 for i in range(ncounts):
107105 for j in range(K):
@@ -113,6 +111,7 @@ def group_last(rank_t[:, :] out,
113111 break
114112 else:
115113 out[i, j] = NAN
114+
116115 else:
117116 out[i, j] = resx[i, j]
118117
@@ -121,7 +120,6 @@ def group_last(rank_t[:, :] out,
121120 # block.
122121 raise RuntimeError("empty group with uint64_t")
123122
124-
125123group_last_float64 = group_last["float64_t"]
126124group_last_float32 = group_last["float32_t"]
127125group_last_int64 = group_last["int64_t"]
@@ -169,8 +167,9 @@ def group_nth(rank_t[:, :] out,
169167 for j in range(K):
170168 val = values[i, j]
171169
172- # not nan
173170 if val == val:
171+ # NB: use _treat_as_na here once
172+ # conditional-nogil is available.
174173 nobs[lab, j] += 1
175174 if nobs[lab, j] == rank:
176175 resx[lab, j] = val
@@ -193,18 +192,11 @@ def group_nth(rank_t[:, :] out,
193192 for j in range(K):
194193 val = values[i, j]
195194
196- # not nan
197- if rank_t is int64_t:
198- # need a special notna check
199- if val != NPY_NAT:
200- nobs[lab, j] += 1
201- if nobs[lab, j] == rank:
202- resx[lab, j] = val
203- else:
204- if val == val:
205- nobs[lab, j] += 1
206- if nobs[lab, j] == rank:
207- resx[lab, j] = val
195+ if not _treat_as_na(val, True):
196+ # TODO: is is_datetimelike=True always right? It treats int64 NPY_NAT as NA.
197+ nobs[lab, j] += 1
198+ if nobs[lab, j] == rank:
199+ resx[lab, j] = val
208200
209201 for i in range(ncounts):
210202 for j in range(K):
@@ -487,17 +479,11 @@ def group_max(groupby_t[:, :] out,
487479 for j in range(K):
488480 val = values[i, j]
489481
490- # not nan
491- if groupby_t is int64_t:
492- if val != nan_val:
493- nobs[lab, j] += 1
494- if val > maxx[lab, j]:
495- maxx[lab, j] = val
496- else:
497- if val == val:
498- nobs[lab, j] += 1
499- if val > maxx[lab, j]:
500- maxx[lab, j] = val
482+ if not _treat_as_na(val, True):
483+ # TODO: is is_datetimelike=True always right? It treats int64 NPY_NAT as NA.
484+ nobs[lab, j] += 1
485+ if val > maxx[lab, j]:
486+ maxx[lab, j] = val
501487
502488 for i in range(ncounts):
503489 for j in range(K):
@@ -563,17 +549,11 @@ def group_min(groupby_t[:, :] out,
563549 for j in range(K):
564550 val = values[i, j]
565551
566- # not nan
567- if groupby_t is int64_t:
568- if val != nan_val:
569- nobs[lab, j] += 1
570- if val < minx[lab, j]:
571- minx[lab, j] = val
572- else:
573- if val == val:
574- nobs[lab, j] += 1
575- if val < minx[lab, j]:
576- minx[lab, j] = val
552+ if not _treat_as_na(val, True):
553+ # TODO: is is_datetimelike=True always right? It treats int64 NPY_NAT as NA.
554+ nobs[lab, j] += 1
555+ if val < minx[lab, j]:
556+ minx[lab, j] = val
577557
578558 for i in range(ncounts):
579559 for j in range(K):
@@ -643,21 +623,13 @@ def group_cummin(groupby_t[:, :] out,
643623 for j in range(K):
644624 val = values[i, j]
645625
646- # val = nan
647- if groupby_t is int64_t:
648- if is_datetimelike and val == NPY_NAT:
649- out[i, j] = NPY_NAT
650- else:
651- mval = accum[lab, j]
652- if val < mval:
653- accum[lab, j] = mval = val
654- out[i, j] = mval
626+ if _treat_as_na(val, is_datetimelike):
627+ out[i, j] = val
655628 else:
656- if val == val:
657- mval = accum[lab, j]
658- if val < mval:
659- accum[lab, j] = mval = val
660- out[i, j] = mval
629+ mval = accum[lab, j]
630+ if val < mval:
631+ accum[lab, j] = mval = val
632+ out[i, j] = mval
661633
662634
663635@cython.boundscheck(False)
@@ -712,17 +684,10 @@ def group_cummax(groupby_t[:, :] out,
712684 for j in range(K):
713685 val = values[i, j]
714686
715- if groupby_t is int64_t:
716- if is_datetimelike and val == NPY_NAT:
717- out[i, j] = NPY_NAT
718- else:
719- mval = accum[lab, j]
720- if val > mval:
721- accum[lab, j] = mval = val
722- out[i, j] = mval
687+ if _treat_as_na(val, is_datetimelike):
688+ out[i, j] = val
723689 else:
724- if val == val:
725- mval = accum[lab, j]
726- if val > mval:
727- accum[lab, j] = mval = val
728- out[i, j] = mval
690+ mval = accum[lab, j]
691+ if val > mval:
692+ accum[lab, j] = mval = val
693+ out[i, j] = mval
0 commit comments