diff --git a/std/algorithm/searching.d b/std/algorithm/searching.d index 2767c8145bb..0ab5871a62f 100644 --- a/std/algorithm/searching.d +++ b/std/algorithm/searching.d @@ -1245,45 +1245,90 @@ private auto extremum(alias map = "a", alias selector = "a < b", Range, if (isInputRange!Range && !isInfinite!Range && !is(CommonType!(ElementType!Range, RangeElementType) == void)) { - alias mapFun = unaryFun!map; - alias selectorFun = binaryFun!selector; + enum isMappingFirst = __traits(compiles, unaryFun!map(seedElement)); + + // check for identity ("a") + static if (isSomeString!(typeof(map))) + enum isIdentity = map == "a"; + else + enum isIdentity = false; + + // shorthand: if a binary function is given, it is the selector + static if (isMappingFirst || isIdentity) + { + alias selectorFun = binaryFun!selector; + } + else + { + alias selectorFun = binaryFun!map; + } alias Element = ElementType!Range; alias CommonElement = CommonType!(Element, RangeElementType); - alias MapType = Unqual!(typeof(mapFun(CommonElement.init))); - Unqual!CommonElement extremeElement = seedElement; - MapType extremeElementMapped = mapFun(extremeElement); - static if (isRandomAccessRange!Range && hasLength!Range) + // - direct access via a random access range is faster + // - if we only have one statement in the loop it can be optimized a lot better + static if (isIdentity || !isMappingFirst) { - foreach (const i; 0 .. r.length) + static if (isRandomAccessRange!Range && hasLength!Range) { - MapType mapElement = mapFun(r[i]); - if (selectorFun(mapElement, extremeElementMapped)) + foreach (const i; 0 .. r.length) { - extremeElement = r[i]; - extremeElementMapped = mapElement; + if (selectorFun(r[i], extremeElement)) + { + extremeElement = r[i]; + } + } + } + else + { + while (!r.empty) + { + if (selectorFun(r.front, extremeElement)) + { + extremeElement = r.front; + } + r.popFront(); } } } else { - while (!r.empty) + alias mapFun = unaryFun!map; + + alias MapType = Unqual!(typeof(mapFun(CommonElement.init))); + MapType extremeElementMapped = mapFun(extremeElement); + static if (isRandomAccessRange!Range && hasLength!Range) { - MapType mapElement = mapFun(r.front); - if (selectorFun(mapElement, extremeElementMapped)) + foreach (const i; 0 .. r.length) { - extremeElement = r.front; - extremeElementMapped = mapElement; + MapType mapElement = mapFun(r[i]); + if (selectorFun(mapElement, extremeElementMapped)) + { + extremeElement = r[i]; + extremeElementMapped = mapElement; + } + } + } + else + { + while (!r.empty) + { + MapType mapElement = mapFun(r.front); + if (selectorFun(mapElement, extremeElementMapped)) + { + extremeElement = r.front; + extremeElementMapped = mapElement; + } + r.popFront(); } - r.popFront(); } } return extremeElement; } -@safe pure nothrow unittest +@safe pure unittest { // allows a custom map to select the extremum assert([[0, 4], [1, 2]].extremum!"a[0]" == [0, 4]); @@ -1292,14 +1337,25 @@ private auto extremum(alias map = "a", alias selector = "a < b", Range, // allows a custom selector for comparison assert([[0, 4], [1, 2]].extremum!("a[0]", "a > b") == [1, 2]); assert([[0, 4], [1, 2]].extremum!("a[1]", "a > b") == [0, 4]); -} -@safe pure nothrow unittest -{ - // allow seeds + // use a custom comparator + import std.math: cmp; + assert([-2., 0, 5].extremum!cmp == 5.0); + assert([-2., 0, 2].extremum!`cmp(a, b) < 0` == -2.0); + + // combine with map + import std.range: enumerate; + assert([-3., 0, 5].enumerate.extremum!(`a.value`, cmp) == tuple(2, 5.0)); + assert([-2., 0, 2].enumerate.extremum!(`a.value`, `cmp(a, b) < 0`) == tuple(0, -2.0)); + + // seed with a custom value int[] arr; assert(arr.extremum(1) == 1); +} +@safe pure nothrow unittest +{ + // 2d seeds int[][] arr2d; assert(arr2d.extremum([1]) == [1]); @@ -1307,6 +1363,43 @@ private auto extremum(alias map = "a", alias selector = "a < b", Range, assert(extremum([2, 3, 4], 1.5) == 1.5); } +@safe pure unittest +{ + import std.range: enumerate, iota; + + // forward ranges + assert(iota(1, 5).extremum() == 1); + assert(iota(2, 5).enumerate.extremum!"a.value" == tuple(0, 2)); + + // should work with const + const(int)[] immArr = [2, 1, 3]; + assert(immArr.extremum == 1); + + // should work with immutable + immutable(int)[] immArr2 = [2, 1, 3]; + assert(immArr2.extremum == 1); + + // with strings + assert(["b", "a", "c"].extremum == "a"); + + // with all dummy ranges + import std.internal.test.dummyrange; + foreach (DummyType; AllDummyRanges) + { + DummyType d; + assert(d.extremum == 1); + } +} + +@nogc @safe nothrow pure unittest +{ + static immutable arr = [7, 3, 4, 2, 1, 8]; + assert(arr.extremum == 1); + + static immutable arr2d = [[1, 9], [3, 1], [4, 2]]; + assert(arr2d.extremum!"a[1]" == arr2d[1]); +} + // find /** Finds an individual element in an input range. Elements of $(D