From 0817591a921f5e4749b17f0e63dec95ec05945b6 Mon Sep 17 00:00:00 2001 From: Ben FrantzDale Date: Fri, 2 Mar 2018 08:15:51 -0500 Subject: [PATCH 1/6] Switch argument order so it's more like std::map, mapping scalar intervals to values. Remove default template argument. --- IntervalTree.h | 135 ++++++++++++++++++++--------------------- interval_tree_test.cpp | 48 ++++++++------- 2 files changed, 90 insertions(+), 93 deletions(-) diff --git a/IntervalTree.h b/IntervalTree.h index 6197a8a..40ee772 100644 --- a/IntervalTree.h +++ b/IntervalTree.h @@ -6,122 +6,118 @@ #include #include -template +template class Interval { public: - K start; - K stop; - T value; - Interval(K s, K e, const T& v) + Scalar start; + Scalar stop; + Value value; + Interval(const Scalar& s, const Scalar& e, const Value& v) : start(s) , stop(e) , value(v) { } }; -template -K intervalStart(const Interval& i) { +template +Value intervalStart(const Interval& i) { return i.start; } -template -K intervalStop(const Interval& i) { +template +Value intervalStop(const Interval& i) { return i.stop; } -template - std::ostream& operator<<(std::ostream& out, Interval& i) { +template + std::ostream& operator<<(std::ostream& out, Interval& i) { out << "Interval(" << i.start << ", " << i.stop << "): " << i.value; return out; } -template -class IntervalStartSorter { +class IntervalStartLessThan { public: - bool operator() (const Interval& a, const Interval& b) { + template + bool operator()(const Interval& a, const Interval& b) { return a.start < b.start; } }; -template +template >> class IntervalTree { - public: - typedef Interval interval; - typedef std::vector intervalVector; - typedef IntervalTree intervalTree; + typedef typename IntervalVector::value_type interval; + typedef IntervalVector interval_vector; - intervalVector intervals; - std::unique_ptr left; - std::unique_ptr right; - K center; + interval_vector intervals; + std::unique_ptr left; + std::unique_ptr right; + Scalar center; - IntervalTree(void) + IntervalTree() : left(nullptr) , right(nullptr) , center(0) - { } + {} -private: - std::unique_ptr copyTree(const intervalTree& orig){ - return std::unique_ptr(new intervalTree(orig)); + std::unique_ptr clone() const { + return std::unique_ptr(new IntervalTree(*this)); } -public: - IntervalTree(const intervalTree& other) + IntervalTree(const IntervalTree& other) : intervals(other.intervals), - left(other.left ? copyTree(*other.left) : nullptr), - right(other.right ? copyTree(*other.right) : nullptr), + left(other.left ? other.left->clone() : nullptr), + right(other.right ? other.right->clone() : nullptr), center(other.center) - { - } + {} -public: + IntervalTree& operator=(IntervalTree&&) = default; + IntervalTree(IntervalTree&&) = default; - IntervalTree& operator=(const intervalTree& other) { + IntervalTree& operator=(const IntervalTree& other) { center = other.center; intervals = other.intervals; - left = other.left ? copyTree(*other.left) : nullptr; - right = other.right ? copyTree(*other.right) : nullptr; + left = other.left ? other.left->clone() : nullptr; + right = other.right ? other.right->clone() : nullptr; return *this; } // Note: changes the order of ivals - IntervalTree( - intervalVector& ivals, + IntervalTree( + interval_vector& ivals, std::size_t depth = 16, std::size_t minbucket = 64, - K leftextent = 0, - K rightextent = 0, + Scalar leftextent = 0, + Scalar rightextent = 0, std::size_t maxbucket = 512 ) : left(nullptr) , right(nullptr) { - --depth; - IntervalStartSorter intervalStartSorter; + IntervalStartLessThan intervalStartLessThan; if (depth == 0 || (ivals.size() < minbucket && ivals.size() < maxbucket)) { - std::sort(ivals.begin(), ivals.end(), intervalStartSorter); + std::sort(ivals.begin(), ivals.end(), intervalStartLessThan); intervals = ivals; } else { if (leftextent == 0 && rightextent == 0) { // sort intervals by start - std::sort(ivals.begin(), ivals.end(), intervalStartSorter); + std::sort(ivals.begin(), ivals.end(), intervalStartLessThan); } - K leftp = 0; - K rightp = 0; - K centerp = 0; + Scalar leftp = 0; + Scalar rightp = 0; + Scalar centerp = 0; if (leftextent || rightextent) { leftp = leftextent; rightp = rightextent; } else { leftp = ivals.front().start; - std::vector stops; + std::vector stops; stops.resize(ivals.size()); - transform(ivals.begin(), ivals.end(), stops.begin(), intervalStop); + transform(ivals.begin(), ivals.end(), stops.begin(), intervalStop); rightp = *max_element(stops.begin(), stops.end()); } @@ -129,10 +125,10 @@ class IntervalTree { centerp = ivals.at(ivals.size() / 2).start; center = centerp; - intervalVector lefts; - intervalVector rights; + interval_vector lefts; + interval_vector rights; - for (typename intervalVector::const_iterator i = ivals.begin(); i != ivals.end(); ++i) { + for (typename interval_vector::const_iterator i = ivals.begin(); i != ivals.end(); ++i) { const interval& interval = *i; if (interval.stop < center) { lefts.push_back(interval); @@ -144,23 +140,23 @@ class IntervalTree { } if (!lefts.empty()) { - left = std::unique_ptr(new intervalTree(lefts, depth, minbucket, leftp, centerp)); + left.reset(new IntervalTree(lefts, depth, minbucket, leftp, centerp)); } if (!rights.empty()) { - right = std::unique_ptr(new intervalTree(rights, depth, minbucket, centerp, rightp)); + right.reset(new IntervalTree(rights, depth, minbucket, centerp, rightp)); } } } - intervalVector findOverlapping(K start, K stop) const { - intervalVector ov; - this->findOverlapping(start, stop, ov); - return ov; + interval_vector findOverlapping(const Scalar start, const Scalar& stop) const { + interval_vector ov; + findOverlapping(start, stop, ov); + return ov; } - void findOverlapping(K start, K stop, intervalVector& overlapping) const { + void findOverlapping(const Scalar& start, const Scalar& stop, interval_vector& overlapping) const { if (!intervals.empty() && ! (stop < intervals.front().start)) { - for (typename intervalVector::const_iterator i = intervals.begin(); i != intervals.end(); ++i) { + for (typename interval_vector::const_iterator i = intervals.begin(); i != intervals.end(); ++i) { const interval& interval = *i; if (interval.stop >= start && interval.start <= stop) { overlapping.push_back(interval); @@ -178,15 +174,15 @@ class IntervalTree { } - intervalVector findContained(K start, K stop) const { - intervalVector contained; - this->findContained(start, stop, contained); - return contained; + interval_vector findContained(const Scalar& start, const Scalar& stop) const { + interval_vector contained; + findContained(start, stop, contained); + return contained; } - void findContained(K start, K stop, intervalVector& contained) const { + void findContained(const Scalar& start, const Scalar& stop, interval_vector& contained) const { if (!intervals.empty() && ! (stop < intervals.front().start)) { - for (typename intervalVector::const_iterator i = intervals.begin(); i != intervals.end(); ++i) { + for (typename interval_vector::const_iterator i = intervals.begin(); i != intervals.end(); ++i) { const interval& interval = *i; if (interval.start >= start && interval.stop <= stop) { contained.push_back(interval); @@ -204,8 +200,7 @@ class IntervalTree { } - ~IntervalTree(void) = default; - + ~IntervalTree() = default; }; #endif diff --git a/interval_tree_test.cpp b/interval_tree_test.cpp index 7528af1..4bbdba5 100644 --- a/interval_tree_test.cpp +++ b/interval_tree_test.cpp @@ -10,18 +10,18 @@ using namespace std; -typedef Interval interval; -typedef vector intervalVector; -typedef IntervalTree intervalTree; +typedef IntervalTree intervalTree; +typedef intervalTree::interval interval; +typedef intervalTree::interval_vector intervalVector; TEST_CASE( "Empty tree" ) { - IntervalTree t; + IntervalTree t; REQUIRE( t.findOverlapping(-1,1).size() == 0 ); } TEST_CASE( "Singleton tree" ) { - vector> values{{1,3,5.5}}; - IntervalTree t{values}; + vector> values{{1,3,5.5}}; + IntervalTree t{values}; SECTION ("Point query on left") { auto v = t.findOverlapping(1,1); @@ -54,8 +54,8 @@ TEST_CASE( "Singleton tree" ) { } TEST_CASE( "Two identical intervals with different contents" ) { - vector> values{{5,10,10.5},{5,10,5.5}}; - IntervalTree t{values}; + vector> values{{5,10,10.5},{5,10,5.5}}; + IntervalTree t{values}; auto v = t.findOverlapping(6,6); REQUIRE( v.size() == 2); @@ -68,29 +68,31 @@ TEST_CASE( "Two identical intervals with different contents" ) { REQUIRE( actual == expected); } -template -K randKey(K floor, K ceiling) { - K range = ceiling - floor; +template +Scalar randKey(Scalar floor, Scalar ceiling) { + Scalar range = ceiling - floor; return floor + range * ((double) rand() / (double) (RAND_MAX + 1.0)); } -template -Interval randomInterval(K maxStart, K maxLength, K maxStop, const T& value) { - K start = randKey(0, maxStart); - K stop = min(randKey(start, start + maxLength), maxStop); - return Interval(start, stop, value); +template +Interval randomInterval(Scalar maxStart, Scalar maxLength, Scalar maxStop, + const Value& value) { + Scalar start = randKey(0, maxStart); + Scalar stop = min(randKey(start, start + maxLength), maxStop); + return Interval(start, stop, value); } int main(int argc, char**argv) { typedef vector countsVector; // a simple sanity check - intervalVector sanityIntervals; - sanityIntervals.push_back(interval(60, 80, true)); - sanityIntervals.push_back(interval(20, 40, true)); - intervalTree sanityTree(sanityIntervals); + typedef intervalTree ITree; + ITree::interval_vector sanityIntervals; + sanityIntervals.push_back(ITree::interval(60, 80, true)); + sanityIntervals.push_back(ITree::interval(20, 40, true)); + ITree sanityTree(sanityIntervals); - intervalVector sanityResults; + ITree::interval_vector sanityResults; sanityTree.findOverlapping(30, 50, sanityResults); assert(sanityResults.size() == 1); sanityResults.clear(); @@ -105,11 +107,11 @@ int main(int argc, char**argv) { // generate a test set of target intervals for (int i = 0; i < 10000; ++i) { - intervals.push_back(randomInterval(100000, 1000, 100000 + 1, true)); + intervals.push_back(randomInterval(100000, 1000, 100000 + 1, true)); } // and queries for (int i = 0; i < 5000; ++i) { - queries.push_back(randomInterval(100000, 1000, 100000 + 1, true)); + queries.push_back(randomInterval(100000, 1000, 100000 + 1, true)); } typedef chrono::high_resolution_clock Clock; From 53435db83206400a1dec0a6b9281fd543c3f7c01 Mon Sep 17 00:00:00 2001 From: Ben FrantzDale Date: Fri, 2 Mar 2018 08:32:38 -0500 Subject: [PATCH 2/6] Clean up construction so it takes an rvalue reference (rather than a non-const reference!) and makes use of the rvalue reference for efficency. --- IntervalTree.h | 43 ++++++++++++++++++++++++------------------ interval_tree_test.cpp | 12 +++++------- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/IntervalTree.h b/IntervalTree.h index 40ee772..7da77f9 100644 --- a/IntervalTree.h +++ b/IntervalTree.h @@ -35,14 +35,20 @@ template return out; } -class IntervalStartLessThan { -public: +struct IntervalStartLessThan { template bool operator()(const Interval& a, const Interval& b) { return a.start < b.start; } }; +struct IntervalStopLessThan { + template + bool operator()(const Interval& a, const Interval& b) { + return a.stop < b.stop; + } +}; + template >> class IntervalTree { @@ -83,27 +89,31 @@ class IntervalTree { return *this; } - // Note: changes the order of ivals IntervalTree( - interval_vector& ivals, + interval_vector&& ivals, std::size_t depth = 16, std::size_t minbucket = 64, Scalar leftextent = 0, Scalar rightextent = 0, - std::size_t maxbucket = 512 - ) + std::size_t maxbucket = 512) : left(nullptr) , right(nullptr) { --depth; - IntervalStartLessThan intervalStartLessThan; if (depth == 0 || (ivals.size() < minbucket && ivals.size() < maxbucket)) { - std::sort(ivals.begin(), ivals.end(), intervalStartLessThan); - intervals = ivals; + std::sort(ivals.begin(), ivals.end(), IntervalStartLessThan()); + intervals = std::move(ivals); + if (!intervals.empty()) { + const auto rightpoint = + std::max_element(intervals.begin(), intervals.end(), + IntervalStopLessThan())->stop; + center = (intervals.front().start + rightpoint) / 2; + } + return; } else { if (leftextent == 0 && rightextent == 0) { // sort intervals by start - std::sort(ivals.begin(), ivals.end(), intervalStartLessThan); + std::sort(ivals.begin(), ivals.end(), IntervalStartLessThan()); } Scalar leftp = 0; @@ -115,14 +125,11 @@ class IntervalTree { rightp = rightextent; } else { leftp = ivals.front().start; - std::vector stops; - stops.resize(ivals.size()); - transform(ivals.begin(), ivals.end(), stops.begin(), intervalStop); - rightp = *max_element(stops.begin(), stops.end()); + rightp = std::max_element(ivals.begin(), ivals.end(), + IntervalStopLessThan())->stop; } - //centerp = ( leftp + rightp ) / 2; - centerp = ivals.at(ivals.size() / 2).start; + centerp = ivals[ivals.size() / 2].start; center = centerp; interval_vector lefts; @@ -140,10 +147,10 @@ class IntervalTree { } if (!lefts.empty()) { - left.reset(new IntervalTree(lefts, depth, minbucket, leftp, centerp)); + left.reset(new IntervalTree(std::move(lefts), depth, minbucket, leftp, centerp)); } if (!rights.empty()) { - right.reset(new IntervalTree(rights, depth, minbucket, centerp, rightp)); + right.reset(new IntervalTree(std::move(rights), depth, minbucket, centerp, rightp)); } } } diff --git a/interval_tree_test.cpp b/interval_tree_test.cpp index 4bbdba5..4f921db 100644 --- a/interval_tree_test.cpp +++ b/interval_tree_test.cpp @@ -20,8 +20,7 @@ TEST_CASE( "Empty tree" ) { } TEST_CASE( "Singleton tree" ) { - vector> values{{1,3,5.5}}; - IntervalTree t{values}; + IntervalTree t{vector>{{1,3,5.5}}}; SECTION ("Point query on left") { auto v = t.findOverlapping(1,1); @@ -53,9 +52,8 @@ TEST_CASE( "Singleton tree" ) { } } -TEST_CASE( "Two identical intervals with different contents" ) { - vector> values{{5,10,10.5},{5,10,5.5}}; - IntervalTree t{values}; +TEST_CASE( "Two identical intervals with different contents" ) { + IntervalTree t{vector>{{5,10,10.5},{5,10,5.5}}}; auto v = t.findOverlapping(6,6); REQUIRE( v.size() == 2); @@ -90,7 +88,7 @@ int main(int argc, char**argv) { ITree::interval_vector sanityIntervals; sanityIntervals.push_back(ITree::interval(60, 80, true)); sanityIntervals.push_back(ITree::interval(20, 40, true)); - ITree sanityTree(sanityIntervals); + ITree sanityTree(std::move(sanityIntervals)); ITree::interval_vector sanityResults; sanityTree.findOverlapping(30, 50, sanityResults); @@ -134,7 +132,7 @@ int main(int argc, char**argv) { cout << "brute force:\t" << ms.count() << "ms" << endl; // using the interval tree - intervalTree tree = intervalTree(intervals); + intervalTree tree = intervalTree(std::move(intervals)); countsVector treecounts; t0 = Clock::now(); for (intervalVector::iterator q = queries.begin(); q != queries.end(); ++q) { From c5cd48c37297231e5d8da7594d9660e662845d13 Mon Sep 17 00:00:00 2001 From: Ben FrantzDale Date: Fri, 2 Mar 2018 13:17:25 -0500 Subject: [PATCH 3/6] Rearrange, add test code. --- IntervalTree.h | 285 +++++++++++++++++++++++++++++------------ interval_tree_test.cpp | 37 +++--- 2 files changed, 220 insertions(+), 102 deletions(-) diff --git a/IntervalTree.h b/IntervalTree.h index 7da77f9..540ce25 100644 --- a/IntervalTree.h +++ b/IntervalTree.h @@ -13,10 +13,10 @@ class Interval { Scalar stop; Value value; Interval(const Scalar& s, const Scalar& e, const Value& v) - : start(s) - , stop(e) - , value(v) - { } + : start(std::min(s, e)) + , stop(std::max(s, e)) + , value(v) + {} }; template @@ -30,36 +30,29 @@ Value intervalStop(const Interval& i) { } template - std::ostream& operator<<(std::ostream& out, Interval& i) { +std::ostream& operator<<(std::ostream& out, const Interval& i) { out << "Interval(" << i.start << ", " << i.stop << "): " << i.value; return out; } -struct IntervalStartLessThan { - template - bool operator()(const Interval& a, const Interval& b) { - return a.start < b.start; - } -}; - -struct IntervalStopLessThan { - template - bool operator()(const Interval& a, const Interval& b) { - return a.stop < b.stop; - } -}; - -template >> +template class IntervalTree { public: - typedef typename IntervalVector::value_type interval; - typedef IntervalVector interval_vector; + typedef Interval interval; + typedef std::vector interval_vector; - interval_vector intervals; - std::unique_ptr left; - std::unique_ptr right; - Scalar center; + + struct IntervalStartCmp { + bool operator()(const interval& a, const interval& b) { + return a.start < b.start; + } + }; + + struct IntervalStopCmp { + bool operator()(const interval& a, const interval& b) { + return a.stop < b.stop; + } + }; IntervalTree() : left(nullptr) @@ -67,6 +60,8 @@ class IntervalTree { , center(0) {} + ~IntervalTree() = default; + std::unique_ptr clone() const { return std::unique_ptr(new IntervalTree(*this)); } @@ -93,32 +88,34 @@ class IntervalTree { interval_vector&& ivals, std::size_t depth = 16, std::size_t minbucket = 64, + std::size_t maxbucket = 512, Scalar leftextent = 0, - Scalar rightextent = 0, - std::size_t maxbucket = 512) - : left(nullptr) - , right(nullptr) + Scalar rightextent = 0) + : left(nullptr) + , right(nullptr) { --depth; + const auto minmaxStop = std::minmax_element(ivals.begin(), ivals.end(), + IntervalStopCmp()); + const auto minmaxStart = std::minmax_element(ivals.begin(), ivals.end(), + IntervalStartCmp()); + if (!ivals.empty()) { + center = (minmaxStart.first->start + minmaxStop.second->stop) / 2; + } + if (leftextent == 0 && rightextent == 0) { + // sort intervals by start + std::sort(ivals.begin(), ivals.end(), IntervalStartCmp()); + } else { + assert(std::is_sorted(ivals.begin(), ivals.end(), IntervalStartCmp())); + } if (depth == 0 || (ivals.size() < minbucket && ivals.size() < maxbucket)) { - std::sort(ivals.begin(), ivals.end(), IntervalStartLessThan()); + std::sort(ivals.begin(), ivals.end(), IntervalStartCmp()); intervals = std::move(ivals); - if (!intervals.empty()) { - const auto rightpoint = - std::max_element(intervals.begin(), intervals.end(), - IntervalStopLessThan())->stop; - center = (intervals.front().start + rightpoint) / 2; - } + assert(is_valid().first); return; } else { - if (leftextent == 0 && rightextent == 0) { - // sort intervals by start - std::sort(ivals.begin(), ivals.end(), IntervalStartLessThan()); - } - Scalar leftp = 0; Scalar rightp = 0; - Scalar centerp = 0; if (leftextent || rightextent) { leftp = leftextent; @@ -126,88 +123,208 @@ class IntervalTree { } else { leftp = ivals.front().start; rightp = std::max_element(ivals.begin(), ivals.end(), - IntervalStopLessThan())->stop; + IntervalStopCmp())->stop; } - centerp = ivals[ivals.size() / 2].start; - center = centerp; - interval_vector lefts; interval_vector rights; - for (typename interval_vector::const_iterator i = ivals.begin(); i != ivals.end(); ++i) { + for (typename interval_vector::const_iterator i = ivals.begin(); + i != ivals.end(); ++i) { const interval& interval = *i; if (interval.stop < center) { lefts.push_back(interval); } else if (interval.start > center) { rights.push_back(interval); } else { + assert(interval.start <= center); + assert(center <= interval.stop); intervals.push_back(interval); } } if (!lefts.empty()) { - left.reset(new IntervalTree(std::move(lefts), depth, minbucket, leftp, centerp)); + left.reset(new IntervalTree(std::move(lefts), + depth, minbucket, maxbucket, + leftp, center)); } if (!rights.empty()) { - right.reset(new IntervalTree(std::move(rights), depth, minbucket, centerp, rightp)); + right.reset(new IntervalTree(std::move(rights), + depth, minbucket, maxbucket, + center, rightp)); } } + assert(is_valid().first); } - interval_vector findOverlapping(const Scalar start, const Scalar& stop) const { - interval_vector ov; - findOverlapping(start, stop, ov); - return ov; - } - - void findOverlapping(const Scalar& start, const Scalar& stop, interval_vector& overlapping) const { + // Call f on all intervals near the range [start, stop]: + template + void visit_near(const Scalar& start, const Scalar& stop, UnaryFunction f) const { if (!intervals.empty() && ! (stop < intervals.front().start)) { - for (typename interval_vector::const_iterator i = intervals.begin(); i != intervals.end(); ++i) { - const interval& interval = *i; - if (interval.stop >= start && interval.start <= stop) { - overlapping.push_back(interval); - } + for (auto & i : intervals) { + f(i); } } - if (left && start <= center) { - left->findOverlapping(start, stop, overlapping); + left->visit_near(start, stop, f); } - if (right && stop >= center) { - right->findOverlapping(start, stop, overlapping); + right->visit_near(start, stop, f); } - } - interval_vector findContained(const Scalar& start, const Scalar& stop) const { - interval_vector contained; - findContained(start, stop, contained); - return contained; + // Call f on all intervals overlapping [start, stop] + template + void visit_overlapping(const Scalar& start, const Scalar& stop, UnaryFunction f) const { + auto filterF = [&](const interval& interval) { + if (interval.stop >= start && interval.start <= stop) { + // Only apply f if overlapping + f(interval); + } + }; + visit_near(start, stop, filterF); } - void findContained(const Scalar& start, const Scalar& stop, interval_vector& contained) const { - if (!intervals.empty() && ! (stop < intervals.front().start)) { - for (typename interval_vector::const_iterator i = intervals.begin(); i != intervals.end(); ++i) { - const interval& interval = *i; - if (interval.start >= start && interval.stop <= stop) { - contained.push_back(interval); - } + // Call f on all intervals contained within [start, stop] + template + void visit_contained(const Scalar& start, const Scalar& stop, UnaryFunction f) const { + auto filterF = [&](const interval& interval) { + if (start <= interval.start && interval.stop <= stop) { + f(interval); } + }; + visit_near(start, stop, filterF); + } + + interval_vector findOverlapping(const Scalar& start, const Scalar& stop) const { + interval_vector result; + visit_overlapping(start, stop, + [&](const interval& interval) { + result.emplace_back(interval); + }); + return result; + } + + interval_vector findContained(const Scalar& start, const Scalar& stop) const { + interval_vector result; + visit_contained(start, stop, + [&](const interval& interval) { + result.push_back(interval); + }); + return result; + } + bool empty() const { + if (left && !left->empty()) { + return false; } + if (!intervals.empty()) { + return false; + } + if (right && !right->empty()) { + return false; + } + return true; + } - if (left && start <= center) { - left->findContained(start, stop, contained); + template + void visit_all(UnaryFunction f) const { + if (left) { + left->visit_all(f); + } + std::for_each(intervals.begin(), intervals.end(), f); + if (right) { + right->visit_all(f); } + } - if (right && stop >= center) { - right->findContained(start, stop, contained); + std::pair extentBruitForce() const { + struct Extent { + std::pair x = {std::numeric_limits::max(), + std::numeric_limits::min() }; + void operator()(const interval & interval) { + x.first = std::min(x.first, interval.start); + x.second = std::max(x.second, interval.stop); + } + }; + Extent extent; + + visit_all([&](const interval & interval) { extent(interval); }); + return extent.x; + } + + // Check all constraints. + // If first is false, second is invalid. + std::pair> is_valid() const { + const auto minmaxStop = std::minmax_element(intervals.begin(), intervals.end(), + IntervalStopCmp()); + const auto minmaxStart = std::minmax_element(intervals.begin(), intervals.end(), + IntervalStartCmp()); + + std::pair> result = {true, { std::numeric_limits::max(), + std::numeric_limits::min() }}; + if (!intervals.empty()) { + result.second.first = std::min(result.second.first, minmaxStart.first->start); + result.second.second = std::min(result.second.second, minmaxStop.second->stop); } + if (left) { + auto valid = left->is_valid(); + result.first &= valid.first; + result.second.first = std::min(result.second.first, valid.second.first); + result.second.second = std::min(result.second.second, valid.second.second); + if (!result.first) { return result; } + if (valid.second.second >= center) { + result.first = false; + return result; + } + } + if (right) { + auto valid = right->is_valid(); + result.first &= valid.first; + result.second.first = std::min(result.second.first, valid.second.first); + result.second.second = std::min(result.second.second, valid.second.second); + if (!result.first) { return result; } + if (valid.second.first <= center) { + result.first = false; + return result; + } + } + if (!std::is_sorted(intervals.begin(), intervals.end(), IntervalStartCmp())) { + result.first = false; + } + return result; + } + friend std::ostream& operator<<(std::ostream& os, const IntervalTree& itree) { + return writeOut(os, itree); } - ~IntervalTree() = default; + friend std::ostream& writeOut(std::ostream& os, const IntervalTree& itree, + std::size_t depth = 0) { + auto pad = [&]() { for (std::size_t i = 0; i != depth; ++i) { os << ' '; } }; + pad(); os << "center: " << itree.center << '\n'; + for (const interval & inter : itree.intervals) { + pad(); os << inter << '\n'; + } + if (itree.left) { + pad(); os << "left:\n"; + writeOut(os, *itree.left, depth + 1); + } else { + pad(); os << "left: nullptr\n"; + } + if (itree.right) { + pad(); os << "right:\n"; + writeOut(os, *itree.right, depth + 1); + } else { + pad(); os << "right: nullptr\n"; + } + return os; + } + +private: + interval_vector intervals; + std::unique_ptr left; + std::unique_ptr right; + Scalar center; }; #endif diff --git a/interval_tree_test.cpp b/interval_tree_test.cpp index 4f921db..e1b77f7 100644 --- a/interval_tree_test.cpp +++ b/interval_tree_test.cpp @@ -20,7 +20,8 @@ TEST_CASE( "Empty tree" ) { } TEST_CASE( "Singleton tree" ) { - IntervalTree t{vector>{{1,3,5.5}}}; + IntervalTree t{ {{1,3,5.5}}, + 1, 64, 512}; SECTION ("Point query on left") { auto v = t.findOverlapping(1,1); @@ -53,7 +54,7 @@ TEST_CASE( "Singleton tree" ) { } TEST_CASE( "Two identical intervals with different contents" ) { - IntervalTree t{vector>{{5,10,10.5},{5,10,5.5}}}; + IntervalTree t{{{5,10,10.5},{5,10,5.5}}}; auto v = t.findOverlapping(6,6); REQUIRE( v.size() == 2); @@ -84,32 +85,32 @@ int main(int argc, char**argv) { typedef vector countsVector; // a simple sanity check - typedef intervalTree ITree; + typedef IntervalTree ITree; ITree::interval_vector sanityIntervals; sanityIntervals.push_back(ITree::interval(60, 80, true)); sanityIntervals.push_back(ITree::interval(20, 40, true)); - ITree sanityTree(std::move(sanityIntervals)); + ITree sanityTree(std::move(sanityIntervals), 16, 1); ITree::interval_vector sanityResults; - sanityTree.findOverlapping(30, 50, sanityResults); + sanityResults = sanityTree.findOverlapping(30, 50); assert(sanityResults.size() == 1); - sanityResults.clear(); - sanityTree.findContained(15, 45, sanityResults); + + sanityResults = sanityTree.findContained(15, 45); assert(sanityResults.size() == 1); srand((unsigned)time(NULL)); - intervalVector intervals; - intervalVector queries; + ITree::interval_vector intervals; + ITree::interval_vector queries; // generate a test set of target intervals for (int i = 0; i < 10000; ++i) { - intervals.push_back(randomInterval(100000, 1000, 100000 + 1, true)); + intervals.push_back(randomInterval(100000, 1000, 100000 + 1, true)); } // and queries for (int i = 0; i < 5000; ++i) { - queries.push_back(randomInterval(100000, 1000, 100000 + 1, true)); + queries.push_back(randomInterval(100000, 1000, 100000 + 1, true)); } typedef chrono::high_resolution_clock Clock; @@ -118,9 +119,9 @@ int main(int argc, char**argv) { // using brute-force search countsVector bruteforcecounts; Clock::time_point t0 = Clock::now(); - for (intervalVector::iterator q = queries.begin(); q != queries.end(); ++q) { - intervalVector results; - for (intervalVector::iterator i = intervals.begin(); i != intervals.end(); ++i) { + for (auto q = queries.begin(); q != queries.end(); ++q) { + ITree::interval_vector results; + for (auto i = intervals.begin(); i != intervals.end(); ++i) { if (i->start >= q->start && i->stop <= q->stop) { results.push_back(*i); } @@ -132,12 +133,12 @@ int main(int argc, char**argv) { cout << "brute force:\t" << ms.count() << "ms" << endl; // using the interval tree - intervalTree tree = intervalTree(std::move(intervals)); + cout << intervals[0]; + ITree tree = ITree(std::move(intervals), 16, 1); countsVector treecounts; t0 = Clock::now(); - for (intervalVector::iterator q = queries.begin(); q != queries.end(); ++q) { - intervalVector results; - tree.findContained(q->start, q->stop, results); + for (auto q = queries.begin(); q != queries.end(); ++q) { + auto results = tree.findContained(q->start, q->stop); treecounts.push_back(results.size()); } t1 = Clock::now(); From 97ec36ecba39e1db4cca292bf5b8d954ca0f92ee Mon Sep 17 00:00:00 2001 From: Ben FrantzDale Date: Fri, 2 Mar 2018 13:20:41 -0500 Subject: [PATCH 4/6] Add visit_overlapping(Scalar, F) for a single point. --- IntervalTree.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/IntervalTree.h b/IntervalTree.h index 540ce25..9a680dd 100644 --- a/IntervalTree.h +++ b/IntervalTree.h @@ -173,6 +173,12 @@ class IntervalTree { } } + // Call f on all intervals crossing pos + template + void visit_overlapping(const Scalar& pos, UnaryFunction f) const { + visit_overlapping(pos, pos, f); + } + // Call f on all intervals overlapping [start, stop] template void visit_overlapping(const Scalar& start, const Scalar& stop, UnaryFunction f) const { From 209ad1146b61e354dd7f503708e5caf3a2d74f6b Mon Sep 17 00:00:00 2001 From: Ben FrantzDale Date: Fri, 2 Mar 2018 13:24:19 -0500 Subject: [PATCH 5/6] Include cassert. --- IntervalTree.h | 1 + 1 file changed, 1 insertion(+) diff --git a/IntervalTree.h b/IntervalTree.h index 9a680dd..1bf2c20 100644 --- a/IntervalTree.h +++ b/IntervalTree.h @@ -5,6 +5,7 @@ #include #include #include +#include template class Interval { From 8adb252846e17abdb3a432f9c367ec60a8744fd2 Mon Sep 17 00:00:00 2001 From: Ben FrantzDale Date: Fri, 2 Mar 2018 14:23:54 -0500 Subject: [PATCH 6/6] Add tests for searching including inf and nan. --- interval_tree_test.cpp | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/interval_tree_test.cpp b/interval_tree_test.cpp index e1b77f7..fbfee2f 100644 --- a/interval_tree_test.cpp +++ b/interval_tree_test.cpp @@ -31,6 +31,39 @@ TEST_CASE( "Singleton tree" ) { REQUIRE( v.front().value == 5.5 ); } + SECTION ("Wild search values") { + typedef IntervalTree IT; + IT t { {{0.0, 1.0, 0}} }; + const auto inf = std::numeric_limits::infinity(); + const auto nan = std::numeric_limits::quiet_NaN(); + auto sanityResults = t.findOverlapping(inf, inf); + assert(sanityResults.size() == 0); + sanityResults = t.findOverlapping(-inf, inf); + assert(sanityResults.size() == 1); + sanityResults = t.findOverlapping(0, inf); + assert(sanityResults.size() == 1); + sanityResults = t.findOverlapping(0.5, inf); + assert(sanityResults.size() == 1); + sanityResults = t.findOverlapping(1.1, inf); + assert(sanityResults.size() == 0); + sanityResults = t.findOverlapping(-inf, 1.0); + assert(sanityResults.size() == 1); + sanityResults = t.findOverlapping(-inf, 0.5); + assert(sanityResults.size() == 1); + sanityResults = t.findOverlapping(-inf, 0.0); + assert(sanityResults.size() == 1); + sanityResults = t.findOverlapping(-inf, -0.1); + assert(sanityResults.size() == 0); + sanityResults = t.findOverlapping(nan, nan); + assert(sanityResults.size() == 0); + sanityResults = t.findOverlapping(-nan, nan); + assert(sanityResults.size() == 0); + sanityResults = t.findOverlapping(nan, 1); + assert(sanityResults.size() == 0); + sanityResults = t.findOverlapping(0, nan); + assert(sanityResults.size() == 0); + } + SECTION ("Point query in middle") { auto v = t.findOverlapping(2,2); REQUIRE( v.size() == 1);