diff --git a/pandas/_libs/intervaltree.pxi.in b/pandas/_libs/intervaltree.pxi.in index 0fb01a2188a57..55d67f000f93a 100644 --- a/pandas/_libs/intervaltree.pxi.in +++ b/pandas/_libs/intervaltree.pxi.in @@ -31,7 +31,9 @@ cdef class IntervalTree(IntervalMixin): we are emulating the IndexEngine interface """ cdef readonly: - object left, right, root, dtype + ndarray left, right + IntervalNode root + object dtype str closed object _is_overlapping, _left_sorter, _right_sorter @@ -203,6 +205,41 @@ cdef sort_values_and_indices(all_values, all_indices, subset): # Nodes # ---------------------------------------------------------------------- +@cython.internal +cdef class IntervalNode: + cdef readonly: + int64_t n_elements, n_center, leaf_size + bint is_leaf_node + + def __repr__(self) -> str: + if self.is_leaf_node: + return ( + f"<{type(self).__name__}: {self.n_elements} elements (terminal)>" + ) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ( + f"<{type(self).__name__}: " + f"pivot {self.pivot}, {self.n_elements} elements " + f"({n_left} left, {n_right} right, {n_center} overlapping)>" + ) + + def counts(self): + """ + Inspect counts on this node + useful for debugging purposes + """ + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + + # we need specialized nodes and leaves to optimize for different dtype and # closed values @@ -240,7 +277,7 @@ NODE_CLASSES = {} @cython.internal -cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode: +cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode(IntervalNode): """Non-terminal node for an IntervalTree Categorizes intervals by those that fall to the left, those that fall to @@ -252,8 +289,6 @@ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode: int64_t[:] center_left_indices, center_right_indices, indices {{dtype}}_t min_left, max_right {{dtype}}_t pivot - int64_t n_elements, n_center, leaf_size - bint is_leaf_node def __init__(self, ndarray[{{dtype}}_t, ndim=1] left, @@ -381,31 +416,6 @@ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode: else: result.extend(self.center_left_indices) - def __repr__(self) -> str: - if self.is_leaf_node: - return ('<{{dtype_title}}Closed{{closed_title}}IntervalNode: ' - '%s elements (terminal)>' % self.n_elements) - else: - n_left = self.left_node.n_elements - n_right = self.right_node.n_elements - n_center = self.n_elements - n_left - n_right - return ('<{{dtype_title}}Closed{{closed_title}}IntervalNode: ' - 'pivot %s, %s elements (%s left, %s right, %s ' - 'overlapping)>' % (self.pivot, self.n_elements, - n_left, n_right, n_center)) - - def counts(self): - """ - Inspect counts on this node - useful for debugging purposes - """ - if self.is_leaf_node: - return self.n_elements - else: - m = len(self.center_left_values) - l = self.left_node.counts() - r = self.right_node.counts() - return (m, (l, r)) NODE_CLASSES['{{dtype}}', '{{closed}}'] = {{dtype_title}}Closed{{closed_title}}IntervalNode