diff --git a/IntervalTree.h b/IntervalTree.h index 6197a8a..1bf2c20 100644 --- a/IntervalTree.h +++ b/IntervalTree.h @@ -5,207 +5,333 @@ #include #include #include +#include -template +template class Interval { public: - K start; - K stop; - T value; - Interval(K s, K e, const T& v) - : start(s) - , stop(e) - , value(v) - { } + Scalar start; + Scalar stop; + Value value; + Interval(const Scalar& s, const Scalar& e, const Value& v) + : start(std::min(s, e)) + , stop(std::max(s, e)) + , value(v) + {} }; -template -K intervalStart(const Interval& i) { +template +Value intervalStart(const Interval& i) { return i.start; } -template -K intervalStop(const Interval& i) { +template +Value intervalStop(const Interval& i) { return i.stop; } -template - std::ostream& operator<<(std::ostream& out, Interval& i) { +template +std::ostream& operator<<(std::ostream& out, const Interval& i) { out << "Interval(" << i.start << ", " << i.stop << "): " << i.value; return out; } -template -class IntervalStartSorter { +template +class IntervalTree { public: - bool operator() (const Interval& a, const Interval& b) { - return a.start < b.start; - } -}; + typedef Interval interval; + typedef std::vector interval_vector; -template -class IntervalTree { -public: - typedef Interval interval; - typedef std::vector intervalVector; - typedef IntervalTree intervalTree; + struct IntervalStartCmp { + bool operator()(const interval& a, const interval& b) { + return a.start < b.start; + } + }; - intervalVector intervals; - std::unique_ptr left; - std::unique_ptr right; - K center; + struct IntervalStopCmp { + bool operator()(const interval& a, const interval& b) { + return a.stop < b.stop; + } + }; - IntervalTree(void) + IntervalTree() : left(nullptr) , right(nullptr) , center(0) - { } + {} -private: - std::unique_ptr copyTree(const intervalTree& orig){ - return std::unique_ptr(new intervalTree(orig)); + ~IntervalTree() = default; + + std::unique_ptr clone() const { + return std::unique_ptr(new IntervalTree(*this)); } -public: - IntervalTree(const intervalTree& other) + IntervalTree(const IntervalTree& other) : intervals(other.intervals), - left(other.left ? copyTree(*other.left) : nullptr), - right(other.right ? copyTree(*other.right) : nullptr), + left(other.left ? other.left->clone() : nullptr), + right(other.right ? other.right->clone() : nullptr), center(other.center) - { - } + {} -public: + IntervalTree& operator=(IntervalTree&&) = default; + IntervalTree(IntervalTree&&) = default; - IntervalTree& operator=(const intervalTree& other) { + IntervalTree& operator=(const IntervalTree& other) { center = other.center; intervals = other.intervals; - left = other.left ? copyTree(*other.left) : nullptr; - right = other.right ? copyTree(*other.right) : nullptr; + left = other.left ? other.left->clone() : nullptr; + right = other.right ? other.right->clone() : nullptr; return *this; } - // Note: changes the order of ivals - IntervalTree( - intervalVector& ivals, + IntervalTree( + interval_vector&& ivals, std::size_t depth = 16, std::size_t minbucket = 64, - K leftextent = 0, - K rightextent = 0, - std::size_t maxbucket = 512 - ) - : left(nullptr) - , right(nullptr) + std::size_t maxbucket = 512, + Scalar leftextent = 0, + Scalar rightextent = 0) + : left(nullptr) + , right(nullptr) { - --depth; - IntervalStartSorter intervalStartSorter; + const auto minmaxStop = std::minmax_element(ivals.begin(), ivals.end(), + IntervalStopCmp()); + const auto minmaxStart = std::minmax_element(ivals.begin(), ivals.end(), + IntervalStartCmp()); + if (!ivals.empty()) { + center = (minmaxStart.first->start + minmaxStop.second->stop) / 2; + } + if (leftextent == 0 && rightextent == 0) { + // sort intervals by start + std::sort(ivals.begin(), ivals.end(), IntervalStartCmp()); + } else { + assert(std::is_sorted(ivals.begin(), ivals.end(), IntervalStartCmp())); + } if (depth == 0 || (ivals.size() < minbucket && ivals.size() < maxbucket)) { - std::sort(ivals.begin(), ivals.end(), intervalStartSorter); - intervals = ivals; + std::sort(ivals.begin(), ivals.end(), IntervalStartCmp()); + intervals = std::move(ivals); + assert(is_valid().first); + return; } else { - if (leftextent == 0 && rightextent == 0) { - // sort intervals by start - std::sort(ivals.begin(), ivals.end(), intervalStartSorter); - } - - K leftp = 0; - K rightp = 0; - K centerp = 0; + Scalar leftp = 0; + Scalar rightp = 0; if (leftextent || rightextent) { leftp = leftextent; rightp = rightextent; } else { leftp = ivals.front().start; - std::vector stops; - stops.resize(ivals.size()); - transform(ivals.begin(), ivals.end(), stops.begin(), intervalStop); - rightp = *max_element(stops.begin(), stops.end()); + rightp = std::max_element(ivals.begin(), ivals.end(), + IntervalStopCmp())->stop; } - //centerp = ( leftp + rightp ) / 2; - centerp = ivals.at(ivals.size() / 2).start; - center = centerp; + interval_vector lefts; + interval_vector rights; - intervalVector lefts; - intervalVector rights; - - for (typename intervalVector::const_iterator i = ivals.begin(); i != ivals.end(); ++i) { + for (typename interval_vector::const_iterator i = ivals.begin(); + i != ivals.end(); ++i) { const interval& interval = *i; if (interval.stop < center) { lefts.push_back(interval); } else if (interval.start > center) { rights.push_back(interval); } else { + assert(interval.start <= center); + assert(center <= interval.stop); intervals.push_back(interval); } } if (!lefts.empty()) { - left = std::unique_ptr(new intervalTree(lefts, depth, minbucket, leftp, centerp)); + left.reset(new IntervalTree(std::move(lefts), + depth, minbucket, maxbucket, + leftp, center)); } if (!rights.empty()) { - right = std::unique_ptr(new intervalTree(rights, depth, minbucket, centerp, rightp)); + right.reset(new IntervalTree(std::move(rights), + depth, minbucket, maxbucket, + center, rightp)); } } + assert(is_valid().first); } - intervalVector findOverlapping(K start, K stop) const { - intervalVector ov; - this->findOverlapping(start, stop, ov); - return ov; - } - - void findOverlapping(K start, K stop, intervalVector& overlapping) const { + // Call f on all intervals near the range [start, stop]: + template + void visit_near(const Scalar& start, const Scalar& stop, UnaryFunction f) const { if (!intervals.empty() && ! (stop < intervals.front().start)) { - for (typename intervalVector::const_iterator i = intervals.begin(); i != intervals.end(); ++i) { - const interval& interval = *i; - if (interval.stop >= start && interval.start <= stop) { - overlapping.push_back(interval); - } + for (auto & i : intervals) { + f(i); } } - if (left && start <= center) { - left->findOverlapping(start, stop, overlapping); + left->visit_near(start, stop, f); } - if (right && stop >= center) { - right->findOverlapping(start, stop, overlapping); + right->visit_near(start, stop, f); } + } + // Call f on all intervals crossing pos + template + void visit_overlapping(const Scalar& pos, UnaryFunction f) const { + visit_overlapping(pos, pos, f); } - intervalVector findContained(K start, K stop) const { - intervalVector contained; - this->findContained(start, stop, contained); - return contained; + // Call f on all intervals overlapping [start, stop] + template + void visit_overlapping(const Scalar& start, const Scalar& stop, UnaryFunction f) const { + auto filterF = [&](const interval& interval) { + if (interval.stop >= start && interval.start <= stop) { + // Only apply f if overlapping + f(interval); + } + }; + visit_near(start, stop, filterF); } - void findContained(K start, K stop, intervalVector& contained) const { - if (!intervals.empty() && ! (stop < intervals.front().start)) { - for (typename intervalVector::const_iterator i = intervals.begin(); i != intervals.end(); ++i) { - const interval& interval = *i; - if (interval.start >= start && interval.stop <= stop) { - contained.push_back(interval); - } + // Call f on all intervals contained within [start, stop] + template + void visit_contained(const Scalar& start, const Scalar& stop, UnaryFunction f) const { + auto filterF = [&](const interval& interval) { + if (start <= interval.start && interval.stop <= stop) { + f(interval); } + }; + visit_near(start, stop, filterF); + } + + interval_vector findOverlapping(const Scalar& start, const Scalar& stop) const { + interval_vector result; + visit_overlapping(start, stop, + [&](const interval& interval) { + result.emplace_back(interval); + }); + return result; + } + + interval_vector findContained(const Scalar& start, const Scalar& stop) const { + interval_vector result; + visit_contained(start, stop, + [&](const interval& interval) { + result.push_back(interval); + }); + return result; + } + bool empty() const { + if (left && !left->empty()) { + return false; } + if (!intervals.empty()) { + return false; + } + if (right && !right->empty()) { + return false; + } + return true; + } - if (left && start <= center) { - left->findContained(start, stop, contained); + template + void visit_all(UnaryFunction f) const { + if (left) { + left->visit_all(f); } + std::for_each(intervals.begin(), intervals.end(), f); + if (right) { + right->visit_all(f); + } + } - if (right && stop >= center) { - right->findContained(start, stop, contained); + std::pair extentBruitForce() const { + struct Extent { + std::pair x = {std::numeric_limits::max(), + std::numeric_limits::min() }; + void operator()(const interval & interval) { + x.first = std::min(x.first, interval.start); + x.second = std::max(x.second, interval.stop); + } + }; + Extent extent; + + visit_all([&](const interval & interval) { extent(interval); }); + return extent.x; + } + + // Check all constraints. + // If first is false, second is invalid. + std::pair> is_valid() const { + const auto minmaxStop = std::minmax_element(intervals.begin(), intervals.end(), + IntervalStopCmp()); + const auto minmaxStart = std::minmax_element(intervals.begin(), intervals.end(), + IntervalStartCmp()); + + std::pair> result = {true, { std::numeric_limits::max(), + std::numeric_limits::min() }}; + if (!intervals.empty()) { + result.second.first = std::min(result.second.first, minmaxStart.first->start); + result.second.second = std::min(result.second.second, minmaxStop.second->stop); + } + if (left) { + auto valid = left->is_valid(); + result.first &= valid.first; + result.second.first = std::min(result.second.first, valid.second.first); + result.second.second = std::min(result.second.second, valid.second.second); + if (!result.first) { return result; } + if (valid.second.second >= center) { + result.first = false; + return result; + } } + if (right) { + auto valid = right->is_valid(); + result.first &= valid.first; + result.second.first = std::min(result.second.first, valid.second.first); + result.second.second = std::min(result.second.second, valid.second.second); + if (!result.first) { return result; } + if (valid.second.first <= center) { + result.first = false; + return result; + } + } + if (!std::is_sorted(intervals.begin(), intervals.end(), IntervalStartCmp())) { + result.first = false; + } + return result; + } + friend std::ostream& operator<<(std::ostream& os, const IntervalTree& itree) { + return writeOut(os, itree); } - ~IntervalTree(void) = default; + friend std::ostream& writeOut(std::ostream& os, const IntervalTree& itree, + std::size_t depth = 0) { + auto pad = [&]() { for (std::size_t i = 0; i != depth; ++i) { os << ' '; } }; + pad(); os << "center: " << itree.center << '\n'; + for (const interval & inter : itree.intervals) { + pad(); os << inter << '\n'; + } + if (itree.left) { + pad(); os << "left:\n"; + writeOut(os, *itree.left, depth + 1); + } else { + pad(); os << "left: nullptr\n"; + } + if (itree.right) { + pad(); os << "right:\n"; + writeOut(os, *itree.right, depth + 1); + } else { + pad(); os << "right: nullptr\n"; + } + return os; + } +private: + interval_vector intervals; + std::unique_ptr left; + std::unique_ptr right; + Scalar center; }; #endif diff --git a/interval_tree_test.cpp b/interval_tree_test.cpp index 7528af1..fbfee2f 100644 --- a/interval_tree_test.cpp +++ b/interval_tree_test.cpp @@ -10,18 +10,18 @@ using namespace std; -typedef Interval interval; -typedef vector intervalVector; -typedef IntervalTree intervalTree; +typedef IntervalTree intervalTree; +typedef intervalTree::interval interval; +typedef intervalTree::interval_vector intervalVector; TEST_CASE( "Empty tree" ) { - IntervalTree t; + IntervalTree t; REQUIRE( t.findOverlapping(-1,1).size() == 0 ); } TEST_CASE( "Singleton tree" ) { - vector> values{{1,3,5.5}}; - IntervalTree t{values}; + IntervalTree t{ {{1,3,5.5}}, + 1, 64, 512}; SECTION ("Point query on left") { auto v = t.findOverlapping(1,1); @@ -31,6 +31,39 @@ TEST_CASE( "Singleton tree" ) { REQUIRE( v.front().value == 5.5 ); } + SECTION ("Wild search values") { + typedef IntervalTree IT; + IT t { {{0.0, 1.0, 0}} }; + const auto inf = std::numeric_limits::infinity(); + const auto nan = std::numeric_limits::quiet_NaN(); + auto sanityResults = t.findOverlapping(inf, inf); + assert(sanityResults.size() == 0); + sanityResults = t.findOverlapping(-inf, inf); + assert(sanityResults.size() == 1); + sanityResults = t.findOverlapping(0, inf); + assert(sanityResults.size() == 1); + sanityResults = t.findOverlapping(0.5, inf); + assert(sanityResults.size() == 1); + sanityResults = t.findOverlapping(1.1, inf); + assert(sanityResults.size() == 0); + sanityResults = t.findOverlapping(-inf, 1.0); + assert(sanityResults.size() == 1); + sanityResults = t.findOverlapping(-inf, 0.5); + assert(sanityResults.size() == 1); + sanityResults = t.findOverlapping(-inf, 0.0); + assert(sanityResults.size() == 1); + sanityResults = t.findOverlapping(-inf, -0.1); + assert(sanityResults.size() == 0); + sanityResults = t.findOverlapping(nan, nan); + assert(sanityResults.size() == 0); + sanityResults = t.findOverlapping(-nan, nan); + assert(sanityResults.size() == 0); + sanityResults = t.findOverlapping(nan, 1); + assert(sanityResults.size() == 0); + sanityResults = t.findOverlapping(0, nan); + assert(sanityResults.size() == 0); + } + SECTION ("Point query in middle") { auto v = t.findOverlapping(2,2); REQUIRE( v.size() == 1); @@ -53,9 +86,8 @@ TEST_CASE( "Singleton tree" ) { } } -TEST_CASE( "Two identical intervals with different contents" ) { - vector> values{{5,10,10.5},{5,10,5.5}}; - IntervalTree t{values}; +TEST_CASE( "Two identical intervals with different contents" ) { + IntervalTree t{{{5,10,10.5},{5,10,5.5}}}; auto v = t.findOverlapping(6,6); REQUIRE( v.size() == 2); @@ -68,48 +100,50 @@ TEST_CASE( "Two identical intervals with different contents" ) { REQUIRE( actual == expected); } -template -K randKey(K floor, K ceiling) { - K range = ceiling - floor; +template +Scalar randKey(Scalar floor, Scalar ceiling) { + Scalar range = ceiling - floor; return floor + range * ((double) rand() / (double) (RAND_MAX + 1.0)); } -template -Interval randomInterval(K maxStart, K maxLength, K maxStop, const T& value) { - K start = randKey(0, maxStart); - K stop = min(randKey(start, start + maxLength), maxStop); - return Interval(start, stop, value); +template +Interval randomInterval(Scalar maxStart, Scalar maxLength, Scalar maxStop, + const Value& value) { + Scalar start = randKey(0, maxStart); + Scalar stop = min(randKey(start, start + maxLength), maxStop); + return Interval(start, stop, value); } int main(int argc, char**argv) { typedef vector countsVector; // a simple sanity check - intervalVector sanityIntervals; - sanityIntervals.push_back(interval(60, 80, true)); - sanityIntervals.push_back(interval(20, 40, true)); - intervalTree sanityTree(sanityIntervals); - - intervalVector sanityResults; - sanityTree.findOverlapping(30, 50, sanityResults); + typedef IntervalTree ITree; + ITree::interval_vector sanityIntervals; + sanityIntervals.push_back(ITree::interval(60, 80, true)); + sanityIntervals.push_back(ITree::interval(20, 40, true)); + ITree sanityTree(std::move(sanityIntervals), 16, 1); + + ITree::interval_vector sanityResults; + sanityResults = sanityTree.findOverlapping(30, 50); assert(sanityResults.size() == 1); - sanityResults.clear(); - sanityTree.findContained(15, 45, sanityResults); + + sanityResults = sanityTree.findContained(15, 45); assert(sanityResults.size() == 1); srand((unsigned)time(NULL)); - intervalVector intervals; - intervalVector queries; + ITree::interval_vector intervals; + ITree::interval_vector queries; // generate a test set of target intervals for (int i = 0; i < 10000; ++i) { - intervals.push_back(randomInterval(100000, 1000, 100000 + 1, true)); + intervals.push_back(randomInterval(100000, 1000, 100000 + 1, true)); } // and queries for (int i = 0; i < 5000; ++i) { - queries.push_back(randomInterval(100000, 1000, 100000 + 1, true)); + queries.push_back(randomInterval(100000, 1000, 100000 + 1, true)); } typedef chrono::high_resolution_clock Clock; @@ -118,9 +152,9 @@ int main(int argc, char**argv) { // using brute-force search countsVector bruteforcecounts; Clock::time_point t0 = Clock::now(); - for (intervalVector::iterator q = queries.begin(); q != queries.end(); ++q) { - intervalVector results; - for (intervalVector::iterator i = intervals.begin(); i != intervals.end(); ++i) { + for (auto q = queries.begin(); q != queries.end(); ++q) { + ITree::interval_vector results; + for (auto i = intervals.begin(); i != intervals.end(); ++i) { if (i->start >= q->start && i->stop <= q->stop) { results.push_back(*i); } @@ -132,12 +166,12 @@ int main(int argc, char**argv) { cout << "brute force:\t" << ms.count() << "ms" << endl; // using the interval tree - intervalTree tree = intervalTree(intervals); + cout << intervals[0]; + ITree tree = ITree(std::move(intervals), 16, 1); countsVector treecounts; t0 = Clock::now(); - for (intervalVector::iterator q = queries.begin(); q != queries.end(); ++q) { - intervalVector results; - tree.findContained(q->start, q->stop, results); + for (auto q = queries.begin(); q != queries.end(); ++q) { + auto results = tree.findContained(q->start, q->stop); treecounts.push_back(results.size()); } t1 = Clock::now();