diff --git a/std/algorithm/comparison.d b/std/algorithm/comparison.d index a6414b0636c..7663a4d310c 100644 --- a/std/algorithm/comparison.d +++ b/std/algorithm/comparison.d @@ -580,43 +580,64 @@ do // cmp /********************************** -Performs three-way lexicographical comparison on two -$(REF_ALTTEXT input ranges, isInputRange, std,range,primitives) -according to predicate `pred`. Iterating `r1` and `r2` in -lockstep, `cmp` compares each element `e1` of `r1` with the -corresponding element `e2` in `r2`. If one of the ranges has been -finished, `cmp` returns a negative value if `r1` has fewer -elements than `r2`, a positive value if `r1` has more elements -than `r2`, and `0` if the ranges have the same number of -elements. +Performs a lexicographical comparison on two +$(REF_ALTTEXT input ranges, isInputRange, std,range,primitives). +Iterating `r1` and `r2` in lockstep, `cmp` compares each element +`e1` of `r1` with the corresponding element `e2` in `r2`. If one +of the ranges has been finished, `cmp` returns a negative value +if `r1` has fewer elements than `r2`, a positive value if `r1` +has more elements than `r2`, and `0` if the ranges have the same +number of elements. If the ranges are strings, `cmp` performs UTF decoding appropriately and compares the ranges one code point at a time. +A custom predicate may be specified, in which case `cmp` performs +a three-way lexicographical comparison using `pred`. Otherwise +the elements are compared using `opCmp`. + Params: - pred = The predicate used for comparison. + pred = Predicate used for comparison. Without a predicate + specified the ordering implied by `opCmp` is used. r1 = The first range. r2 = The second range. Returns: - 0 if both ranges compare equal. -1 if the first differing element of $(D - r1) is less than the corresponding element of `r2` according to $(D - pred). 1 if the first differing element of `r2` is less than the - corresponding element of `r1` according to `pred`. - + `0` if the ranges compare equal. A negative value if `r1` is a prefix of `r2` or + the first differing element of `r1` is less than the corresponding element of `r2` + according to `pred`. A positive value if `r2` is a prefix of `r1` or the first + differing element of `r2` is less than the corresponding element of `r1` + according to `pred`. + +Note: + An earlier version of the documentation incorrectly stated that `-1` is the + only negative value returned and `1` is the only positive value returned. + Whether that is true depends on the types being compared. */ -int cmp(alias pred = "a < b", R1, R2)(R1 r1, R2 r2) +auto cmp(R1, R2)(R1 r1, R2 r2) if (isInputRange!R1 && isInputRange!R2) { static if (!(isSomeString!R1 && isSomeString!R2)) { for (;; r1.popFront(), r2.popFront()) { - if (r1.empty) return -cast(int)!r2.empty; - if (r2.empty) return !r1.empty; - auto a = r1.front, b = r2.front; - if (binaryFun!pred(a, b)) return -1; - if (binaryFun!pred(b, a)) return 1; + static if (is(typeof(r1.front.opCmp(r2.front)) R)) + alias Result = R; + else + alias Result = int; + if (r2.empty) return Result(!r1.empty); + if (r1.empty) return Result(-1); + static if (is(typeof(r1.front.opCmp(r2.front)))) + { + auto c = r1.front.opCmp(r2.front); + if (c != 0) return c; + } + else + { + auto a = r1.front, b = r2.front; + if (a < b) return -1; + if (b < a) return 1; + } } } else @@ -624,35 +645,28 @@ if (isInputRange!R1 && isInputRange!R2) import core.stdc.string : memcmp; import std.utf : decode; - static if (is(typeof(pred) : string)) - enum isLessThan = pred == "a < b"; - else - enum isLessThan = false; - // For speed only static int threeWay(size_t a, size_t b) { - static if (size_t.sizeof == int.sizeof && isLessThan) + static if (size_t.sizeof == int.sizeof) return a - b; else - return binaryFun!pred(b, a) ? 1 : binaryFun!pred(a, b) ? -1 : 0; + // Faster than return b < a ? 1 : a < b ? -1 : 0; + return (a > b) - (a < b); } // For speed only // @@@BUG@@@ overloading should be allowed for nested functions static int threeWayInt(int a, int b) { - static if (isLessThan) - return a - b; - else - return binaryFun!pred(b, a) ? 1 : binaryFun!pred(a, b) ? -1 : 0; + return a - b; } - static if (typeof(r1[0]).sizeof == typeof(r2[0]).sizeof && isLessThan) + static if (typeof(r1[0]).sizeof == typeof(r2[0]).sizeof) { static if (typeof(r1[0]).sizeof == 1) { immutable len = min(r1.length, r2.length); - immutable result = __ctfe ? + int result = __ctfe ? { foreach (i; 0 .. len) { @@ -663,17 +677,21 @@ if (isInputRange!R1 && isInputRange!R2) }() : () @trusted { return memcmp(r1.ptr, r2.ptr, len); }(); if (result) return result; + return threeWay(r1.length, r2.length); } else { - auto p1 = r1.ptr, p2 = r2.ptr, - pEnd = p1 + min(r1.length, r2.length); - for (; p1 != pEnd; ++p1, ++p2) + return () @trusted { - if (*p1 != *p2) return threeWayInt(cast(int) *p1, cast(int) *p2); - } + auto p1 = r1.ptr, p2 = r2.ptr, + pEnd = p1 + min(r1.length, r2.length); + for (; p1 != pEnd; ++p1, ++p2) + { + if (*p1 != *p2) return threeWayInt(int(*p1), int(*p2)); + } + return threeWay(r1.length, r2.length); + }(); } - return threeWay(r1.length, r2.length); } else { @@ -683,14 +701,58 @@ if (isInputRange!R1 && isInputRange!R2) if (i2 == r2.length) return threeWay(r1.length, i1); immutable c1 = decode(r1, i1), c2 = decode(r2, i2); - if (c1 != c2) return threeWayInt(cast(int) c1, cast(int) c2); + if (c1 != c2) return threeWayInt(int(c1), int(c2)); + } + } + } +} + +/// ditto +int cmp(alias pred, R1, R2)(R1 r1, R2 r2) +if (isInputRange!R1 && isInputRange!R2) +{ + static if (!(isSomeString!R1 && isSomeString!R2)) + { + for (;; r1.popFront(), r2.popFront()) + { + if (r2.empty) return !r1.empty; + if (r1.empty) return -1; + auto a = r1.front, b = r2.front; + if (binaryFun!pred(a, b)) return -1; + if (binaryFun!pred(b, a)) return 1; + } + } + else + { + import std.utf : decode; + + // For speed only + static int threeWayCompareLength(size_t a, size_t b) + { + static if (size_t.sizeof == int.sizeof) + return a - b; + else + // Faster than return b < a ? 1 : a < b ? -1 : 0; + return (a > b) - (a < b); + } + + for (size_t i1, i2;;) + { + if (i1 == r1.length) return threeWayCompareLength(i2, r2.length); + if (i2 == r2.length) return threeWayCompareLength(r1.length, i1); + immutable c1 = decode(r1, i1), + c2 = decode(r2, i2); + if (c1 != c2) + { + if (binaryFun!pred(c2, c1)) return 1; + if (binaryFun!pred(c1, c2)) return -1; } } } } /// -@safe unittest +pure @safe unittest { int result; @@ -712,6 +774,8 @@ if (isInputRange!R1 && isInputRange!R2) assert(result > 0); result = cmp("aaa", "aaa"d); assert(result == 0); + result = cmp("aaa"d, "aaa"d); + assert(result == 0); result = cmp(cast(int[])[], cast(int[])[]); assert(result == 0); result = cmp([1, 2, 3], [1, 2, 3]); @@ -724,6 +788,106 @@ if (isInputRange!R1 && isInputRange!R2) assert(result > 0); } +/// Example predicate that compares individual elements in reverse lexical order +pure @safe unittest +{ + int result; + + result = cmp!"a > b"("abc", "abc"); + assert(result == 0); + result = cmp!"a > b"("", ""); + assert(result == 0); + result = cmp!"a > b"("abc", "abcd"); + assert(result < 0); + result = cmp!"a > b"("abcd", "abc"); + assert(result > 0); + result = cmp!"a > b"("abc"d, "abd"); + assert(result > 0); + result = cmp!"a > b"("bbc", "abc"w); + assert(result < 0); + result = cmp!"a > b"("aaa", "aaaa"d); + assert(result < 0); + result = cmp!"a > b"("aaaa", "aaa"d); + assert(result > 0); + result = cmp!"a > b"("aaa", "aaa"d); + assert(result == 0); + result = cmp("aaa"d, "aaa"d); + assert(result == 0); + result = cmp!"a > b"(cast(int[])[], cast(int[])[]); + assert(result == 0); + result = cmp!"a > b"([1, 2, 3], [1, 2, 3]); + assert(result == 0); + result = cmp!"a > b"([1, 3, 2], [1, 2, 3]); + assert(result < 0); + result = cmp!"a > b"([1, 2, 3], [1L, 2, 3, 4]); + assert(result < 0); + result = cmp!"a > b"([1L, 2, 3], [1, 2]); + assert(result > 0); +} + +@nogc nothrow pure @safe unittest +{ + // Issue 18286: cmp for string with custom predicate fails if distinct chars can compare equal + static bool ltCi(dchar a, dchar b)// less than, case insensitive + { + import std.ascii : toUpper; + return toUpper(a) < toUpper(b); + } + static assert(cmp!ltCi("apple2", "APPLE1") > 0); + static assert(cmp!ltCi("apple1", "APPLE2") < 0); + static assert(cmp!ltCi("apple", "APPLE1") < 0); + static assert(cmp!ltCi("APPLE", "apple1") < 0); + static assert(cmp!ltCi("apple", "APPLE") == 0); +} + +@nogc nothrow @safe unittest +{ + // Issue 18280: for non-string ranges check that opCmp is evaluated only once per pair. + static int ctr = 0; + struct S + { + int opCmp(ref const S rhs) const + { + ++ctr; + return 0; + } + } + immutable S[4] a; + immutable S[4] b; + immutable result = cmp(a[], b[]); + assert(result == 0, "neither should compare greater than the other!"); + assert(ctr == a.length, "opCmp should be called exactly once per pair of items!"); +} + +nothrow pure @safe unittest +{ + // Test cmp when opCmp returns float. + struct F + { + float value; + float opCmp(const ref F rhs) const + { + return value - rhs.value; + } + } + auto result = cmp([F(1), F(2), F(3)], [F(1), F(2), F(3)]); + assert(result == 0); + assert(is(typeof(result) == float)); + result = cmp([F(1), F(3), F(2)], [F(1), F(2), F(3)]); + assert(result > 0); + result = cmp([F(1), F(2), F(3)], [F(1), F(2), F(3), F(4)]); + assert(result < 0); + result = cmp([F(1), F(2), F(3)], [F(1), F(2)]); + assert(result > 0); +} + +nothrow pure @safe unittest +{ + // Parallelism (was broken by inferred return type "immutable int") + import std.parallelism : task; + auto t = task!cmp("foo", "bar"); +} + // equal /** Compares two ranges for equality, as defined by predicate `pred`