@@ -31,7 +31,9 @@ cdef class IntervalTree(IntervalMixin):
3131 we are emulating the IndexEngine interface
3232 """
3333 cdef readonly:
34- object left, right, root, dtype
34+ ndarray left, right
35+ IntervalNode root
36+ object dtype
3537 str closed
3638 object _is_overlapping, _left_sorter, _right_sorter
3739
@@ -203,6 +205,41 @@ cdef sort_values_and_indices(all_values, all_indices, subset):
203205# Nodes
204206# ----------------------------------------------------------------------
205207
208+ @cython.internal
209+ cdef class IntervalNode:
210+ cdef readonly:
211+ int64_t n_elements, n_center, leaf_size
212+ bint is_leaf_node
213+
214+ def __repr__(self) -> str:
215+ if self.is_leaf_node:
216+ return (
217+ f"<{type(self).__name__}: {self.n_elements} elements (terminal)>"
218+ )
219+ else:
220+ n_left = self.left_node.n_elements
221+ n_right = self.right_node.n_elements
222+ n_center = self.n_elements - n_left - n_right
223+ return (
224+ f"<{type(self).__name__}: "
225+ f"pivot {self.pivot}, {self.n_elements} elements "
226+ f"({n_left} left, {n_right} right, {n_center} overlapping)>"
227+ )
228+
229+ def counts(self):
230+ """
231+ Inspect counts on this node
232+ useful for debugging purposes
233+ """
234+ if self.is_leaf_node:
235+ return self.n_elements
236+ else:
237+ m = len(self.center_left_values)
238+ l = self.left_node.counts()
239+ r = self.right_node.counts()
240+ return (m, (l, r))
241+
242+
206243# we need specialized nodes and leaves to optimize for different dtype and
207244# closed values
208245
@@ -240,7 +277,7 @@ NODE_CLASSES = {}
240277
241278
242279@cython.internal
243- cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
280+ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode(IntervalNode) :
244281 """Non-terminal node for an IntervalTree
245282
246283 Categorizes intervals by those that fall to the left, those that fall to
@@ -252,8 +289,6 @@ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
252289 int64_t[:] center_left_indices, center_right_indices, indices
253290 {{dtype}}_t min_left, max_right
254291 {{dtype}}_t pivot
255- int64_t n_elements, n_center, leaf_size
256- bint is_leaf_node
257292
258293 def __init__(self,
259294 ndarray[{{dtype}}_t, ndim=1] left,
@@ -381,31 +416,6 @@ cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode:
381416 else:
382417 result.extend(self.center_left_indices)
383418
384- def __repr__(self) -> str:
385- if self.is_leaf_node:
386- return ('<{{dtype_title}}Closed{{closed_title}}IntervalNode: '
387- '%s elements (terminal)>' % self.n_elements)
388- else:
389- n_left = self.left_node.n_elements
390- n_right = self.right_node.n_elements
391- n_center = self.n_elements - n_left - n_right
392- return ('<{{dtype_title}}Closed{{closed_title}}IntervalNode: '
393- 'pivot %s, %s elements (%s left, %s right, %s '
394- 'overlapping)>' % (self.pivot, self.n_elements,
395- n_left, n_right, n_center))
396-
397- def counts(self):
398- """
399- Inspect counts on this node
400- useful for debugging purposes
401- """
402- if self.is_leaf_node:
403- return self.n_elements
404- else:
405- m = len(self.center_left_values)
406- l = self.left_node.counts()
407- r = self.right_node.counts()
408- return (m, (l, r))
409419
410420NODE_CLASSES['{{dtype}}',
411421 '{{closed}}'] = {{dtype_title}}Closed{{closed_title}}IntervalNode
0 commit comments