Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 103 additions & 25 deletions src/cpp/flann/algorithms/dist.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ typedef unsigned __int64 uint64_t;

#include "flann/defines.h"

#ifdef __SSE2__
#include <xmmintrin.h>
#endif


namespace flann
{
Expand Down Expand Up @@ -137,6 +141,11 @@ struct L2
typedef T ElementType;
typedef typename Accumulator<T>::Type ResultType;

/* Maps an iterator type to the type used for read-only access.
 * Primary template: any type that is not a raw pointer is passed through
 * unchanged. */
template <typename Iterator>
struct ConstIterator
{
    typedef Iterator type;
};

/* Raw pointers: U* maps to const U*. A pointer whose pointee is already
 * const (e.g. const float*) therefore maps to itself. */
template <typename U>
struct ConstIterator<U*>
{
    typedef const U* type;
};

/**
* Compute the squared Euclidean distance between two vectors.
*
Expand All @@ -149,31 +158,13 @@ struct L2
template <typename Iterator1, typename Iterator2>
ResultType operator()(Iterator1 a, Iterator2 b, size_t size, ResultType worst_dist = -1) const
{
    /* Returns the sum of squared element-wise differences over the first
     * `size` elements of a and b.
     *
     * When worst_dist is positive the computation may stop early: as soon
     * as the running sum exceeds worst_dist after a group of four elements,
     * that partial sum is returned. A non-positive worst_dist (the default)
     * disables the early exit.
     *
     * All the work is done by `Compute`. a and b are only ever read, so they
     * are converted to their read-only iterator types first: raw pointers
     * become pointers-to-const, which lets the (const float*, const float*)
     * overloads of VectorizedLoop be selected without needing separate
     * const and non-const versions. Other iterator types pass through
     * unchanged. */
    typedef typename ConstIterator<Iterator1>::type ReadIterator1;
    typedef typename ConstIterator<Iterator2>::type ReadIterator2;

    return Compute(static_cast<ReadIterator1>(a),
                   static_cast<ReadIterator2>(b),
                   size, worst_dist);
}

/**
Expand All @@ -187,6 +178,93 @@ struct L2
{
return (a-b)*(a-b);
}

private:
/* Compile-time checks of ConstIterator on raw pointers: both float* and
 * const float* must map to const float*. */
static_assert(std::is_same<typename ConstIterator<float*>::type, const float*>::value, "");
static_assert(std::is_same<typename ConstIterator<const float*>::type, const float*>::value, "");

/* Squared-distance kernel behind operator().
 *
 * Sums the squared differences of the first `size` elements of a and b.
 * The bulk of the range is handed to VectorizedLoop in groups of four; any
 * leftover elements are handled one by one afterwards.
 *
 * If worst_dist is positive and VectorizedLoop reports that the running sum
 * has exceeded it, that partial sum is returned immediately. */
template <typename ConstIterator1, typename ConstIterator2>
ResultType Compute(ConstIterator1 a, ConstIterator2 b, size_t size, ResultType worst_dist) const
{
    ResultType sum = ResultType();
    ConstIterator1 end = a + size;

    /* Grouped part: advances a and b past every complete group it consumes. */
    const bool early_exit_enabled = worst_dist > 0;
    if (!early_exit_enabled) {
        VectorizedLoop(a, end, b, sum);
    }
    else if (VectorizedLoop(a, end, b, sum, worst_dist)) {
        return sum;
    }

    /* Remainder: elements left over after the last complete group. */
    while (a < end) {
        const ResultType delta = (ResultType)(*a++ - *b++);
        sum += delta * delta;
    }
    return sum;
}

/* Default loop implementation.. */
template <typename ConstIterator1, typename ConstIterator2>
static inline bool VectorizedLoop(ConstIterator1& a, ConstIterator1 last, ConstIterator2& b, ResultType& result, ResultType worst_dist = 0) {
ConstIterator1 lastgroup = last - 3;
/* Process 4 items with each loop for efficiency. */
while (a < lastgroup) {
ResultType diff0 = (ResultType)(a[0] - b[0]);
ResultType diff1 = (ResultType)(a[1] - b[1]);
ResultType diff2 = (ResultType)(a[2] - b[2]);
ResultType diff3 = (ResultType)(a[3] - b[3]);
result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
a += 4;
b += 4;

if ((worst_dist>0) &&(result>worst_dist)) {
return true;
}
}
return false;
};

#ifdef __SSE2__
/* SSE loop for (const float*, const float*) -> float with early exit.
 *
 * Consumes a and b four floats at a time while at least four remain before
 * `last`, adding the squared differences to `result` and advancing both
 * pointers. Because the running sum has to be compared against worst_dist
 * after every group, the four lanes are reduced to a scalar on every
 * iteration.
 *
 * Returns true as soon as worst_dist is positive and `result` exceeds it,
 * false once fewer than four elements remain. */
static inline bool VectorizedLoop(const float*& a, const float* last, const float*& b, float& result, float worst_dist) {
    /* Test the remaining length rather than forming `last - 3`: for inputs
     * shorter than three elements that pointer would lie before the start
     * of the array, which is undefined pointer arithmetic. */
    while (last - a >= 4) {
        const __m128 diff = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
        const __m128 sqr = _mm_mul_ps(diff, diff);

        /* Spill the four squared differences and add them up. */
        float elements[4];
        _mm_storeu_ps(elements, sqr);
        result += elements[3] + elements[2] + elements[1] + elements[0];

        a += 4;
        b += 4;

        if ((worst_dist > 0) && (result > worst_dist)) {
            return true;
        }
    }
    return false;
}
/* SSE loop for (const float*, const float*) -> float without early exit.
 *
 * Consumes a and b four floats at a time while at least four remain before
 * `last`, advancing both pointers. Squared differences are kept in four
 * independent lane sums during the loop and reduced to a scalar only once
 * at the end.
 *
 * The reduced total is ADDED to `result`, matching the other VectorizedLoop
 * overloads, so any value already held in `result` is preserved. */
static inline void VectorizedLoop(const float*& a, const float* last, const float*& b, float& result) {
    __m128 lane_sums = _mm_setzero_ps();

    /* Test the remaining length rather than forming `last - 3`: for inputs
     * shorter than three elements that pointer would lie before the start
     * of the array, which is undefined pointer arithmetic. */
    while (last - a >= 4) {
        const __m128 diff = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
        lane_sums = _mm_add_ps(lane_sums, _mm_mul_ps(diff, diff));
        a += 4;
        b += 4;
    }

    /* Horizontal reduction of the four lane sums. */
    float lanes[4];
    _mm_storeu_ps(lanes, lane_sums);
    result += lanes[0] + lanes[1] + lanes[2] + lanes[3];
}
#endif

};


Expand Down