428 lines
14 KiB
Cython
428 lines
14 KiB
Cython
"""
|
|
Template for intervaltree
|
|
|
|
WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
|
|
"""
|
|
|
|
from pandas._libs.algos import is_monotonic
|
|
|
|
# Fused scalar types accepted as query points. The node classes below pick
# ``int_scalar_t`` for int64 nodes, ``uint_scalar_t`` for uint64 nodes and
# ``scalar_t`` for float64 nodes (see the ``fused_prefix`` template variable).
ctypedef fused int_scalar_t:
    int64_t
    float64_t

ctypedef fused uint_scalar_t:
    uint64_t
    float64_t

# Union of the two fused types above; used for the ``target`` argument of
# IntervalTree.get_indexer / get_indexer_non_unique.
ctypedef fused scalar_t:
    int_scalar_t
    uint_scalar_t
|
|
|
|
# ----------------------------------------------------------------------
|
|
# IntervalTree
|
|
# ----------------------------------------------------------------------
|
|
|
|
cdef class IntervalTree(IntervalMixin):
    """A centered interval tree

    Based off the algorithm described on Wikipedia:
    https://en.wikipedia.org/wiki/Interval_tree

    we are emulating the IndexEngine interface
    """
    cdef readonly:
        ndarray left, right
        IntervalNode root
        object dtype
        str closed
        object _is_overlapping, _left_sorter, _right_sorter
        Py_ssize_t _na_count

    def __init__(self, left, right, closed='right', leaf_size=100):
        """
        Parameters
        ----------
        left, right : np.ndarray[ndim=1]
            Left and right bounds for each interval. Intervals whose left
            bound is NaN are counted in ``_na_count`` and excluded from the
            tree.
        closed : {'left', 'right', 'both', 'neither'}, optional
            Whether the intervals are closed on the left-side, right-side, both
            or neither. Defaults to 'right'.
        leaf_size : int, optional
            Parameter that controls when the tree switches from creating nodes
            to brute-force search. Tune this parameter to optimize query
            performance.

        Raises
        ------
        ValueError
            If ``closed`` is not one of the four accepted options.
        """
        if closed not in ['left', 'right', 'both', 'neither']:
            # f-string rather than %-formatting: ``"%s" % closed`` raises
            # TypeError (or mis-formats) when ``closed`` is a tuple
            raise ValueError(f"invalid option for 'closed': {closed}")

        left = np.asarray(left)
        right = np.asarray(right)
        # both bounds are stored with their common dtype
        self.dtype = np.result_type(left, right)
        self.left = np.asarray(left, dtype=self.dtype)
        self.right = np.asarray(right, dtype=self.dtype)

        # positions of the intervals in the original (unmasked) input
        indices = np.arange(len(left), dtype='int64')

        self.closed = closed

        # GH 23352: ensure no nan in nodes
        mask = ~np.isnan(self.left)
        self._na_count = len(mask) - mask.sum()
        self.left = self.left[mask]
        self.right = self.right[mask]
        indices = indices[mask]

        # node classes are specialized on (dtype name, closed)
        node_cls = NODE_CLASSES[str(self.dtype), closed]
        self.root = node_cls(self.left, self.right, indices, leaf_size)

    @property
    def left_sorter(self) -> np.ndarray:
        """How to sort the left labels; this is used for binary search

        Computed on first access and cached as self._left_sorter.
        """
        if self._left_sorter is None:
            self._left_sorter = np.argsort(self.left)
        return self._left_sorter

    @property
    def right_sorter(self) -> np.ndarray:
        """How to sort the right labels

        Computed on first access and cached as self._right_sorter.
        """
        if self._right_sorter is None:
            self._right_sorter = np.argsort(self.right)
        return self._right_sorter

    @property
    def is_overlapping(self) -> bool:
        """
        Determine if the IntervalTree contains overlapping intervals.
        Cached as self._is_overlapping.
        """
        if self._is_overlapping is not None:
            return self._is_overlapping

        # <= when both sides closed since endpoints can overlap
        op = le if self.closed == 'both' else lt

        # overlap if start of current interval < end of previous interval
        # (current and previous in terms of sorted order by left/start side)
        current = self.left[self.left_sorter[1:]]
        previous = self.right[self.left_sorter[:-1]]
        self._is_overlapping = bool(op(current, previous).any())

        return self._is_overlapping

    @property
    def is_monotonic_increasing(self) -> bool:
        """
        Return True if the IntervalTree is monotonic increasing (only equal or
        increasing values), else False
        """
        # any NaN interval dropped in __init__ makes the tree non-monotonic
        if self._na_count > 0:
            return False
        # np.lexsort treats the LAST key as primary: sort by left, then right
        values = [self.right, self.left]

        sort_order = np.lexsort(values)
        return is_monotonic(sort_order, False)[0]

    def get_indexer(self, scalar_t[:] target) -> np.ndarray:
        """Return the positions corresponding to unique intervals that overlap
        with the given array of scalar targets.

        A target that matches no interval yields -1.

        Raises
        ------
        KeyError
            If a target is contained in more than one interval.
        """

        # TODO: write get_indexer_intervals
        cdef:
            Py_ssize_t old_len
            Py_ssize_t i
            Int64Vector result

        result = Int64Vector()
        old_len = 0
        for i in range(len(target)):
            try:
                self.root.query(result, target[i])
            except OverflowError:
                # overflow -> no match, which is already handled below
                pass

            # compare the result length before/after the query to learn how
            # many intervals matched this target
            if result.data.n == old_len:
                result.append(-1)
            elif result.data.n > old_len + 1:
                raise KeyError(
                    'indexer does not intersect a unique set of intervals')
            old_len = result.data.n
        return result.to_array().astype('intp')

    def get_indexer_non_unique(self, scalar_t[:] target):
        """Return the positions corresponding to intervals that overlap with
        the given array of scalar targets. Non-unique positions are repeated.

        Returns
        -------
        tuple of two intp ndarrays
            The matching positions (-1 for a target without any match) and
            the positions within ``target`` that had no match.
        """
        cdef:
            Py_ssize_t old_len
            Py_ssize_t i
            Int64Vector result, missing

        result = Int64Vector()
        missing = Int64Vector()
        old_len = 0
        for i in range(len(target)):
            try:
                self.root.query(result, target[i])
            except OverflowError:
                # overflow -> no match, which is already handled below
                pass

            # nothing appended by the query: record the miss
            if result.data.n == old_len:
                result.append(-1)
                missing.append(i)
            old_len = result.data.n
        return (result.to_array().astype('intp'),
                missing.to_array().astype('intp'))

    def __repr__(self) -> str:
        return (
            f"<IntervalTree[{self.dtype},{self.closed}]: "
            f"{self.root.n_elements} elements>"
        )

    # compat with IndexEngine interface
    def clear_mapping(self) -> None:
        pass
|
|
|
|
|
|
cdef take(ndarray source, ndarray indices):
    """Select the elements of a 1D ndarray found at the given positions
    """
    # axis 0 is the only axis of a 1D array
    selected = PyArray_Take(source, indices, 0)
    return selected
|
|
|
|
|
|
cdef sort_values_and_indices(all_values, all_indices, subset):
    """Pick the ``subset`` positions out of values and indices, then order
    both by ascending value

    Returns the sorted values and the indices rearranged the same way.
    """
    subset_values = take(all_values, subset)
    subset_indices = take(all_indices, subset)
    # a single argsort of the values drives the ordering of both outputs
    order = PyArray_ArgSort(subset_values, 0, NPY_QUICKSORT)
    return take(subset_values, order), take(subset_indices, order)
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# Nodes
|
|
# ----------------------------------------------------------------------
|
|
|
|
@cython.internal
cdef class IntervalNode:
    # Base class holding the attributes shared by all specialized nodes.
    # The methods below also read left_node, right_node, pivot and
    # center_left_values, which are declared on the specialized subclasses.
    cdef readonly:
        int64_t n_elements, n_center, leaf_size
        bint is_leaf_node

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        if self.is_leaf_node:
            return f"<{cls_name}: {self.n_elements} elements (terminal)>"

        n_left = self.left_node.n_elements
        n_right = self.right_node.n_elements
        # whatever is in neither child is stored on this node itself
        n_overlapping = self.n_elements - n_left - n_right
        return (
            f"<{cls_name}: "
            f"pivot {self.pivot}, {self.n_elements} elements "
            f"({n_left} left, {n_right} right, {n_overlapping} overlapping)>"
        )

    def counts(self):
        """
        Inspect counts on this node
        useful for debugging purposes

        A leaf returns its element count; any other node returns
        ``(n_center, (left_counts, right_counts))`` built recursively.
        """
        if self.is_leaf_node:
            return self.n_elements

        n_mid = len(self.center_left_values)
        left_counts = self.left_node.counts()
        right_counts = self.right_node.counts()
        return (n_mid, (left_counts, right_counts))
|
|
|
|
|
|
# we need specialized nodes and leaves to optimize for different dtype and
|
|
# closed values
|
|
|
|
{{py:

# (closed, cmp_left, cmp_right): an interval contains ``point`` when
# ``left cmp_left point cmp_right right`` holds
closed_ops = (
    ('left', '<=', '<'),
    ('right', '<', '<='),
    ('both', '<=', '<='),
    ('neither', '<', '<'),
)

# (dtype, fused_prefix): the prefix selects the fused scalar type used for
# the ``point`` argument of the generated ``query`` method
dtype_prefixes = (
    ('float64', ''),
    ('int64', 'int_'),
    ('uint64', 'uint_'),
)

nodes = []
for dtype, fused_prefix in dtype_prefixes:
    for closed, cmp_left, cmp_right in closed_ops:
        # the converse swaps strict and non-strict comparison
        cmp_left_converse = '<=' if cmp_left == '<' else '<'
        cmp_right_converse = '<=' if cmp_right == '<' else '<'
        node = (dtype, dtype.title(),
                closed, closed.title(),
                cmp_left,
                cmp_right,
                cmp_left_converse,
                cmp_right_converse,
                fused_prefix)
        nodes.append(node)

}}
|
|
|
|
# maps (dtype name, closed) -> specialized node class; filled in by the
# template loop below and read by IntervalTree.__init__
NODE_CLASSES = {}

{{for dtype, dtype_title, closed, closed_title, cmp_left, cmp_right,
      cmp_left_converse, cmp_right_converse, fused_prefix in nodes}}


@cython.internal
cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode(IntervalNode):
    """Node for an IntervalTree

    A non-terminal node categorizes intervals by those that fall to the
    left, those that fall to the right, and those that overlap with the
    pivot. A terminal (leaf) node stores its intervals directly and is
    searched linearly.
    """
    cdef readonly:
        {{dtype_title}}Closed{{closed_title}}IntervalNode left_node, right_node
        {{dtype}}_t[:] center_left_values, center_right_values, left, right
        int64_t[:] center_left_indices, center_right_indices, indices
        {{dtype}}_t min_left, max_right
        {{dtype}}_t pivot

    def __init__(self,
                 ndarray[{{dtype}}_t, ndim=1] left,
                 ndarray[{{dtype}}_t, ndim=1] right,
                 ndarray[int64_t, ndim=1] indices,
                 int64_t leaf_size):
        """
        Parameters
        ----------
        left, right : ndarray
            Bounds of the intervals stored in this node and below it.
        indices : ndarray[int64]
            Position reported by ``query`` for each interval.
        leaf_size : int
            Nodes with at most this many intervals become leaves.
        """

        self.n_elements = len(left)
        self.leaf_size = leaf_size

        # min_left and max_right are used to speed-up query by skipping
        # query on sub-nodes. If this node has size 0, query is cheap,
        # so these values don't matter.
        if left.size > 0:
            self.min_left = left.min()
            self.max_right = right.max()
        else:
            self.min_left = 0
            self.max_right = 0

        as_leaf = self.n_elements <= leaf_size
        if not as_leaf:
            # calculate a pivot so we can create child nodes
            median = np.median(left / 2 + right / 2)
            if np.isnan(median):
                # e.g. an interval spanning (-inf, inf): every comparison
                # against a NaN pivot is False, so query could not tell
                # the sides apart; fall back to linear search
                as_leaf = True
            else:
                self.pivot = median
                left_set, right_set, center_set = self.classify_intervals(
                    left, right)
                # a child receiving every interval would be identical to
                # this node and construction would recurse forever (e.g.
                # many identical zero-length intervals, or an infinite
                # pivot); fall back to linear search
                as_leaf = (len(left_set) == self.n_elements
                           or len(right_set) == self.n_elements)

        if as_leaf:
            # make this a terminal (leaf) node
            self.is_leaf_node = True
            self.left = left
            self.right = right
            self.indices = indices
            self.n_center = 0
        else:
            self.is_leaf_node = False

            self.left_node = self.new_child_node(left, right,
                                                 indices, left_set)
            self.right_node = self.new_child_node(left, right,
                                                  indices, right_set)

            # the overlapping intervals are stored twice: ordered by left
            # bound and ordered by right bound, so query can stop scanning
            # at the first non-matching bound
            self.center_left_values, self.center_left_indices = \
                sort_values_and_indices(left, indices, center_set)
            self.center_right_values, self.center_right_indices = \
                sort_values_and_indices(right, indices, center_set)
            self.n_center = len(self.center_left_indices)

    @cython.wraparound(False)
    @cython.boundscheck(False)
    cdef classify_intervals(self, {{dtype}}_t[:] left, {{dtype}}_t[:] right):
        """Classify the given intervals based upon whether they fall to the
        left, right, or overlap with this node's pivot.

        Returns three int64 arrays of positions into ``left``/``right``:
        (left of pivot, right of pivot, overlapping the pivot).
        """
        cdef:
            Int64Vector left_ind, right_ind, overlapping_ind
            Py_ssize_t i

        left_ind = Int64Vector()
        right_ind = Int64Vector()
        overlapping_ind = Int64Vector()

        for i in range(self.n_elements):
            if right[i] {{cmp_right_converse}} self.pivot:
                # interval ends before the pivot
                left_ind.append(i)
            elif self.pivot {{cmp_left_converse}} left[i]:
                # interval starts after the pivot
                right_ind.append(i)
            else:
                overlapping_ind.append(i)

        return (left_ind.to_array(),
                right_ind.to_array(),
                overlapping_ind.to_array())

    cdef new_child_node(self,
                        ndarray[{{dtype}}_t, ndim=1] left,
                        ndarray[{{dtype}}_t, ndim=1] right,
                        ndarray[int64_t, ndim=1] indices,
                        ndarray[int64_t, ndim=1] subset):
        """Create a new child node holding the ``subset`` positions of the
        given bounds and indices, with this node's leaf_size.
        """
        left = take(left, subset)
        right = take(right, subset)
        indices = take(indices, subset)
        return {{dtype_title}}Closed{{closed_title}}IntervalNode(
            left, right, indices, self.leaf_size)

    @cython.wraparound(False)
    @cython.boundscheck(False)
    @cython.initializedcheck(False)
    cpdef query(self, Int64Vector result, {{fused_prefix}}scalar_t point):
        """Recursively query this node and its sub-nodes for intervals that
        overlap with the query point.

        The index of every matching interval is appended to ``result``.
        """
        cdef:
            int64_t[:] indices
            {{dtype}}_t[:] values
            Py_ssize_t i

        if self.is_leaf_node:
            # Once we get down to a certain size, it doesn't make sense to
            # continue the binary tree structure. Instead, we use linear
            # search.
            for i in range(self.n_elements):
                if self.left[i] {{cmp_left}} point {{cmp_right}} self.right[i]:
                    result.append(self.indices[i])
        else:
            # There are child nodes. Based on comparing our query to the pivot,
            # look at the center values, then go to the relevant child.
            if point < self.pivot:
                # center intervals all reach the pivot, so only the left
                # bound needs checking; values are sorted ascending
                values = self.center_left_values
                indices = self.center_left_indices
                for i in range(self.n_center):
                    if not values[i] {{cmp_left}} point:
                        break
                    result.append(indices[i])
                if point {{cmp_right}} self.left_node.max_right:
                    self.left_node.query(result, point)
            elif point > self.pivot:
                # mirror image: only the right bound needs checking, scanned
                # from the largest value downwards
                values = self.center_right_values
                indices = self.center_right_indices
                for i in range(self.n_center - 1, -1, -1):
                    if not point {{cmp_right}} values[i]:
                        break
                    result.append(indices[i])
                if self.right_node.min_left {{cmp_left}} point:
                    self.right_node.query(result, point)
            else:
                # point equals the pivot: every center interval matches
                result.extend(self.center_left_indices)


NODE_CLASSES['{{dtype}}',
             '{{closed}}'] = {{dtype_title}}Closed{{closed_title}}IntervalNode

{{endfor}}
|