Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 206 additions & 42 deletions std/algorithm/comparison.d
Original file line number Diff line number Diff line change
Expand Up @@ -580,79 +580,93 @@ do

// cmp
/**********************************
Performs three-way lexicographical comparison on two
$(REF_ALTTEXT input ranges, isInputRange, std,range,primitives)
according to predicate `pred`. Iterating `r1` and `r2` in
lockstep, `cmp` compares each element `e1` of `r1` with the
corresponding element `e2` in `r2`. If one of the ranges has been
finished, `cmp` returns a negative value if `r1` has fewer
elements than `r2`, a positive value if `r1` has more elements
than `r2`, and `0` if the ranges have the same number of
elements.
Performs a lexicographical comparison on two
$(REF_ALTTEXT input ranges, isInputRange, std,range,primitives).
Iterating `r1` and `r2` in lockstep, `cmp` compares each element
`e1` of `r1` with the corresponding element `e2` in `r2`. If one
of the ranges has been finished, `cmp` returns a negative value
if `r1` has fewer elements than `r2`, a positive value if `r1`
has more elements than `r2`, and `0` if the ranges have the same
number of elements.

If the ranges are strings, `cmp` performs UTF decoding
appropriately and compares the ranges one code point at a time.

A custom predicate may be specified, in which case `cmp` performs
a three-way lexicographical comparison using `pred`. Otherwise
the elements are compared using `opCmp`.

Params:
pred = The predicate used for comparison.
pred = Predicate used for comparison. Without a predicate
specified the ordering implied by `opCmp` is used.
r1 = The first range.
r2 = The second range.

Returns:
0 if both ranges compare equal. -1 if the first differing element of $(D
r1) is less than the corresponding element of `r2` according to $(D
pred). 1 if the first differing element of `r2` is less than the
corresponding element of `r1` according to `pred`.

`0` if the ranges compare equal. A negative value if `r1` is a prefix of `r2` or
the first differing element of `r1` is less than the corresponding element of `r2`
according to `pred`. A positive value if `r2` is a prefix of `r1` or the first
differing element of `r2` is less than the corresponding element of `r1`
according to `pred`.

Note:
An earlier version of the documentation incorrectly stated that `-1` is the
only negative value returned and `1` is the only positive value returned.
Whether that is true depends on the types being compared.
*/
int cmp(alias pred = "a < b", R1, R2)(R1 r1, R2 r2)
auto cmp(R1, R2)(R1 r1, R2 r2)
if (isInputRange!R1 && isInputRange!R2)
{
static if (!(isSomeString!R1 && isSomeString!R2))
{
for (;; r1.popFront(), r2.popFront())
{
if (r1.empty) return -cast(int)!r2.empty;
if (r2.empty) return !r1.empty;
auto a = r1.front, b = r2.front;
if (binaryFun!pred(a, b)) return -1;
if (binaryFun!pred(b, a)) return 1;
static if (is(typeof(r1.front.opCmp(r2.front)) R))
alias Result = R;
else
alias Result = int;
if (r2.empty) return Result(!r1.empty);
if (r1.empty) return Result(-1);
static if (is(typeof(r1.front.opCmp(r2.front))))
{
auto c = r1.front.opCmp(r2.front);
if (c != 0) return c;
}
else
{
auto a = r1.front, b = r2.front;
if (a < b) return -1;
if (b < a) return 1;
}
}
}
else
{
import core.stdc.string : memcmp;
import std.utf : decode;

static if (is(typeof(pred) : string))
enum isLessThan = pred == "a < b";
else
enum isLessThan = false;

// For speed only
static int threeWay(size_t a, size_t b)
{
static if (size_t.sizeof == int.sizeof && isLessThan)
static if (size_t.sizeof == int.sizeof)
return a - b;
else
return binaryFun!pred(b, a) ? 1 : binaryFun!pred(a, b) ? -1 : 0;
// Faster than return b < a ? 1 : a < b ? -1 : 0;
return (a > b) - (a < b);
}
// For speed only
// @@@BUG@@@ overloading should be allowed for nested functions
static int threeWayInt(int a, int b)
{
static if (isLessThan)
return a - b;
else
return binaryFun!pred(b, a) ? 1 : binaryFun!pred(a, b) ? -1 : 0;
return a - b;
}

static if (typeof(r1[0]).sizeof == typeof(r2[0]).sizeof && isLessThan)
static if (typeof(r1[0]).sizeof == typeof(r2[0]).sizeof)
{
static if (typeof(r1[0]).sizeof == 1)
{
immutable len = min(r1.length, r2.length);
immutable result = __ctfe ?
int result = __ctfe ?
{
foreach (i; 0 .. len)
{
Expand All @@ -663,17 +677,21 @@ if (isInputRange!R1 && isInputRange!R2)
}()
: () @trusted { return memcmp(r1.ptr, r2.ptr, len); }();
if (result) return result;
return threeWay(r1.length, r2.length);
}
else
{
auto p1 = r1.ptr, p2 = r2.ptr,
pEnd = p1 + min(r1.length, r2.length);
for (; p1 != pEnd; ++p1, ++p2)
return () @trusted
{
if (*p1 != *p2) return threeWayInt(cast(int) *p1, cast(int) *p2);
}
auto p1 = r1.ptr, p2 = r2.ptr,
pEnd = p1 + min(r1.length, r2.length);
for (; p1 != pEnd; ++p1, ++p2)
{
if (*p1 != *p2) return threeWayInt(int(*p1), int(*p2));
}
return threeWay(r1.length, r2.length);
}();
}
return threeWay(r1.length, r2.length);
}
else
{
Expand All @@ -683,14 +701,58 @@ if (isInputRange!R1 && isInputRange!R2)
if (i2 == r2.length) return threeWay(r1.length, i1);
immutable c1 = decode(r1, i1),
c2 = decode(r2, i2);
if (c1 != c2) return threeWayInt(cast(int) c1, cast(int) c2);
if (c1 != c2) return threeWayInt(int(c1), int(c2));
}
}
}
}

/// ditto
int cmp(alias pred, R1, R2)(R1 r1, R2 r2)
if (isInputRange!R1 && isInputRange!R2)
{
static if (!(isSomeString!R1 && isSomeString!R2))
{
for (;; r1.popFront(), r2.popFront())
{
if (r2.empty) return !r1.empty;
if (r1.empty) return -1;
auto a = r1.front, b = r2.front;
if (binaryFun!pred(a, b)) return -1;
if (binaryFun!pred(b, a)) return 1;
}
}
else
{
import std.utf : decode;

// For speed only
static int threeWayCompareLength(size_t a, size_t b)
{
static if (size_t.sizeof == int.sizeof)
return a - b;
else
// Faster than return b < a ? 1 : a < b ? -1 : 0;
return (a > b) - (a < b);
}

for (size_t i1, i2;;)
{
if (i1 == r1.length) return threeWayCompareLength(i2, r2.length);
if (i2 == r2.length) return threeWayCompareLength(r1.length, i1);
immutable c1 = decode(r1, i1),
c2 = decode(r2, i2);
if (c1 != c2)
{
if (binaryFun!pred(c2, c1)) return 1;
if (binaryFun!pred(c1, c2)) return -1;
}
}
}
}

///
@safe unittest
pure @safe unittest
{
int result;

Expand All @@ -712,6 +774,8 @@ if (isInputRange!R1 && isInputRange!R2)
assert(result > 0);
result = cmp("aaa", "aaa"d);
assert(result == 0);
result = cmp("aaa"d, "aaa"d);
assert(result == 0);
result = cmp(cast(int[])[], cast(int[])[]);
assert(result == 0);
result = cmp([1, 2, 3], [1, 2, 3]);
Expand All @@ -724,6 +788,106 @@ if (isInputRange!R1 && isInputRange!R2)
assert(result > 0);
}

/// Example predicate that compares individual elements in reverse lexical order
pure @safe unittest
{
int result;

result = cmp!"a > b"("abc", "abc");
assert(result == 0);
result = cmp!"a > b"("", "");
assert(result == 0);
result = cmp!"a > b"("abc", "abcd");
assert(result < 0);
result = cmp!"a > b"("abcd", "abc");
assert(result > 0);
result = cmp!"a > b"("abc"d, "abd");
assert(result > 0);
result = cmp!"a > b"("bbc", "abc"w);
assert(result < 0);
result = cmp!"a > b"("aaa", "aaaa"d);
assert(result < 0);
result = cmp!"a > b"("aaaa", "aaa"d);
assert(result > 0);
result = cmp!"a > b"("aaa", "aaa"d);
assert(result == 0);
result = cmp("aaa"d, "aaa"d);
assert(result == 0);
result = cmp!"a > b"(cast(int[])[], cast(int[])[]);
assert(result == 0);
result = cmp!"a > b"([1, 2, 3], [1, 2, 3]);
assert(result == 0);
result = cmp!"a > b"([1, 3, 2], [1, 2, 3]);
assert(result < 0);
result = cmp!"a > b"([1, 2, 3], [1L, 2, 3, 4]);
assert(result < 0);
result = cmp!"a > b"([1L, 2, 3], [1, 2]);
assert(result > 0);
}

@nogc nothrow pure @safe unittest
{
// Issue 18286: cmp for string with custom predicate fails if distinct chars can compare equal
static bool ltCi(dchar a, dchar b)// less than, case insensitive
{
import std.ascii : toUpper;
return toUpper(a) < toUpper(b);
}
static assert(cmp!ltCi("apple2", "APPLE1") > 0);
static assert(cmp!ltCi("apple1", "APPLE2") < 0);
static assert(cmp!ltCi("apple", "APPLE1") < 0);
static assert(cmp!ltCi("APPLE", "apple1") < 0);
static assert(cmp!ltCi("apple", "APPLE") == 0);
}

@nogc nothrow @safe unittest
{
// Issue 18280: for non-string ranges check that opCmp is evaluated only once per pair.
static int ctr = 0;
struct S
{
int opCmp(ref const S rhs) const
{
++ctr;
return 0;
}
}
immutable S[4] a;
immutable S[4] b;
immutable result = cmp(a[], b[]);
assert(result == 0, "neither should compare greater than the other!");
assert(ctr == a.length, "opCmp should be called exactly once per pair of items!");
}

nothrow pure @safe unittest
{
// Test cmp when opCmp returns float.
struct F
{
float value;
float opCmp(const ref F rhs) const
{
return value - rhs.value;
}
}
auto result = cmp([F(1), F(2), F(3)], [F(1), F(2), F(3)]);
assert(result == 0);
assert(is(typeof(result) == float));
result = cmp([F(1), F(3), F(2)], [F(1), F(2), F(3)]);
assert(result > 0);
result = cmp([F(1), F(2), F(3)], [F(1), F(2), F(3), F(4)]);
assert(result < 0);
result = cmp([F(1), F(2), F(3)], [F(1), F(2)]);
assert(result > 0);
}

nothrow pure @safe unittest
{
// Parallelism (was broken by inferred return type "immutable int")
import std.parallelism : task;
auto t = task!cmp("foo", "bar");
}

// equal
/**
Compares two ranges for equality, as defined by predicate `pred`
Expand Down