diff --git a/std/algorithm/iteration.d b/std/algorithm/iteration.d index 58298117cdb..67e74fbcfb4 100644 --- a/std/algorithm/iteration.d +++ b/std/algorithm/iteration.d @@ -18,6 +18,8 @@ $(T2 cumulativeFold, $(D cumulativeFold!((a, b) => a + b)([1, 2, 3, 4])) returns a lazily-evaluated range containing the successive reduced values `1`, `3`, `6`, `10`.) +$(T2 cumulativeSum, + Same as $(D cumulativeFold), but specialized for accurate summation.) $(T2 each, $(D each!writeln([1, 2, 3])) eagerly prints the numbers $(D 1), $(D 2) and $(D 3) on their own lines.) @@ -3483,6 +3485,315 @@ The number of seeds must be correspondingly increased. } } +/++ +Performs a summation of the given $(REF_ALTTEXT input range, isInputRange, +std, range, primitives) `r`, and provides the intermediate results of the +summation as an $(REF_ALTTEXT input range, isInputRange, std, range, +primitives). `cumulativeSum` is conceptually equivalent to +`cumulativeFold!((a, b) => a + b)`, but for floating point summations the +$(HTTP en.wikipedia.org/wiki/Kahan_summation, Kahan summation) algorithm is +used to reduce accuracy loss from cancellation errors. + +When called without a seed, the seed type is deduced from the $(REF_ALTTEXT +element type, ElementType, std,range,primitives) of `r` and the seed value is +0. If the $(REF_ALTTEXT element type, ElementType, std, range, primitives) of +`r` is a $(REF_ALTTEXT floating point type, isFloatingPoint, std, traits), then +the seed type will be deduced to be the most precise type available from either +`double` or `real`. + +Params: + r = any $(REF_ALTTEXT input range, isInputRange, std, range, primitives) + s = a seed value that gives the initial value of the summation + +Returns: + An $(REF_ALTTEXT input range, isInputRange, std, range, primitives) + containing the intermediate results of the summation of `r`. + +See_Also: + $(HTTPS en.wikipedia.org/wiki/Prefix_sum, Prefix Sum) + + $(LREF cumulativeFold) provides the intermediate results of generic + reduction operations on ranges. + + $(LREF sum) performs a summation of a range without providing intermediate + results. + +/ +auto cumulativeSum(Range)(Range r) + if (isInputRange!Range && __traits(compiles, r.front + r.front)) +{ + static if (isFloatingPoint!(ElementType!Range)) + { + // Most precise seed type available from either double or real. + alias Seed = typeof(0.0 + r.front); + } + else + { + // Deduce seed type from the result of a single addition. + alias Seed = typeof(r.front + r.front); + } + + return r.cumulativeSum(Seed(0)); +} + +/// Ditto +auto cumulativeSum(Range, Seed)(Range r, Seed s) + if (isInputRange!Range && __traits(compiles, r.front + r.front)) +{ + static if (isFloatingPoint!Seed) + { + static struct Result + { + this(Range r, Seed s) + { + _r = r; + if (_r.empty) return; + _s = s; + sumFront; + } + + @property + auto front() + in + { + assert(!empty, + "Attempting to fetch the front of an empty cumulativeSum"); + } + body + { + return _s; + } + + void popFront() + in + { + assert(!empty, + "Attempting to popFront an empty cumulativeSum"); + } + body + { + _r.popFront; + if (_r.empty) return; + sumFront; + } + + static if (isInfinite!Range) + { + enum empty = false; + } + else + { + @property + bool empty() + { + return _r.empty; + } + } + + static if (isForwardRange!Range) + { + @property + auto save() + { + auto result = this; + result._r = _r.save; + return result; + } + } + + static if (hasLength!Range) + { + @property + size_t length() + { + return _r.length; + } + } + private: + Range _r; + Seed _s; + Seed _c = 0; + + void sumFront() + { + // One iteration of Kahan summation. + immutable y = _r.front - _c; + immutable t = _s + y; + _c = (t - _s) - y; + _s = t; + } + } + + return Result(r, s); + } + else + { + // Default to naive summation for integral values. + return r.cumulativeFold!((a, b) => a + b)(s); + } +} + +/// +@safe pure nothrow +unittest +{ + import std.algorithm.comparison : equal; + import std.range : iota, repeat; + + // Partial sum of integral values: + assert(cumulativeSum([1, 2, 3, 4, 5]).equal([1, 3, 6, 10, 15])); + + // Using ranges and UFCS: + assert(iota(1, 6).cumulativeSum.equal([1, 3, 6, 10, 15])); + + // With seed value: + assert(iota(1, 6).cumulativeSum(-15).equal([-14, -12, -9, -5, 0])); + + + // Partial sum of floating point values: + assert(cumulativeSum([1.0, 2.0, 3.0, 4.0, 5.0]) + .equal([1.0, 3.0, 6.0, 10.0, 15.0])); + + // With seed value: + assert(cumulativeSum([1.0, 2.0, 3.0, 4.0, 5.0], -15.0) + .equal([-14.0, -12.0, -9.0, -5.0, 0.0])); + + + // Partial sum with integral promotion: + assert(cumulativeSum([false, true, true, false, true]) + .equal([0, 1, 2, 2, 3])); + + // Similarly, a seed can be used to force floating point summation: + assert(cumulativeSum([false, true, true, false, true], 0.0) + .equal([0.0, 1.0, 2.0, 2.0, 3.0])); + + + // The result may overflow: + assert(uint.max.repeat(3).cumulativeSum + .equal([4294967295U, 4294967294U, 4294967293U])); + + // But a seed can be used to change the sumation primitive: + assert(uint.max.repeat(3).cumulativeSum(ulong.init) + .equal([4294967295UL, 8589934590UL, 12884901885UL])); +} + +/++ +`cumulativeSum` uses Kahan summation to give more accurate results than +naive summation for ranges of floating point values. + +/ +@safe pure nothrow +unittest +{ + import std.math : approxEqual; + + // Despite summing 'large' and 'small' numbers the loss of significance is + // a non-issue. + assert(cumulativeSum([10000, 3.14159, 2.71828, 1.41421, 1.61803, -10000]) + .approxEqual([10000, 10003.1, 10005.9, 10007.3, 10008.9, 8.89211])); + + // Another example with a 'large' seed value. + assert(cumulativeSum([6.28318, 1.73205, 3.33333, 2.23606, -10000], 10000.0) + .approxEqual([10006.3, 10008.0, 10011.3, 10013.6, 13.5846])); + + // A more extreme example. + assert(cumulativeSum([71850, 1.594e-11, 7.91182e-11, 2.36169e-11, -71850]) + .approxEqual([71850, 71850, 71850, 71850, 1.18675e-10])); +} + +@safe pure nothrow +unittest +{ + import std.range.primitives : ElementType; + import std.algorithm.comparison : equal; + + // Integral types: + + static assert(is(ElementType!(typeof(cumulativeSum([cast(byte)1]))) == int)); + static assert(is(ElementType!(typeof(cumulativeSum([cast(ubyte)1]))) == int)); + static assert(is(ElementType!(typeof(cumulativeSum([1, 2, 3, 4]))) == int)); + static assert(is(ElementType!(typeof(cumulativeSum([1U, 2U, 3U, 4U]))) == uint)); + static assert(is(ElementType!(typeof(cumulativeSum([1L, 2L, 3L, 4L]))) == long)); + static assert(is(ElementType!(typeof(cumulativeSum([1UL, 2UL, 3UL, 4UL]))) == ulong)); + + int[] empty; + assert(cumulativeSum(empty).empty); + assert(cumulativeSum([42]).equal([42])); + assert(cumulativeSum([42, 43]).equal([42, 85])); + assert(cumulativeSum([42, 43, 44]).equal([42, 85, 129])); + assert(cumulativeSum([42, 43, 44, 45]).equal([42, 85, 129, 174])); +} + +@safe pure nothrow +unittest +{ + import std.range.primitives : ElementType; + import std.algorithm.comparison : equal; + + // Floating point types: + + static assert(is(ElementType!(typeof(cumulativeSum([1F, 2F, 3F, 4F]))) == double)); + static assert(is(ElementType!(typeof(cumulativeSum([1.0, 2.0, 3.0, 4.0]))) == double)); + static assert(is(ElementType!(typeof(cumulativeSum([1.0L, 2.0L, 3.0L, 4.0L]))) == real)); + const(float[]) a = [1F, 2F, 3F, 4F]; + static assert(is(ElementType!(typeof(cumulativeSum(a))) == double)); + const(float)[] b = [1F, 2F, 3F, 4F]; + static assert(is(ElementType!(typeof(cumulativeSum(b))) == double)); + + double[] empty; + assert(cumulativeSum(empty).empty); + assert(cumulativeSum([42.0]).equal([42])); + assert(cumulativeSum([42.0, 43.0]).equal([42, 85])); + assert(cumulativeSum([42.0, 43.0, 44.0]).equal([42, 85, 129])); + assert(cumulativeSum([42.0, 43.0, 44.0, 45.5]) + .equal([42, 85, 129, 174.5])); +} + +@safe @nogc pure nothrow +unittest +{ + import std.algorithm.comparison : equal; + import std.range : iota, repeat; + + foreach (n; iota(50)) + { + assert(repeat(1, n).cumulativeSum(-1.0).equal(iota(n))); + } +} + +@safe pure nothrow +unittest +{ + import std.algorithm.comparison : equal; + import std.internal.test.dummyrange : AllDummyRanges, propagatesLength, + propagatesRangeType, RangeType; + import std.algorithm.iteration : map, joiner; + import std.range : isForwardRange, chunks; + + foreach (DummyType; AllDummyRanges) + { + DummyType d; + + // Test floating point values as integral values are handled by + // cumulativeFold. + auto f = d.map!(n => cast(double)n); + + static if (isForwardRange!(typeof(f))) + { + assert(f.chunks(1).map!cumulativeSum.joiner.equal(f)); + } + + auto s = f.cumulativeSum; + + static assert(propagatesLength!(typeof(s), typeof(f))); + + static if (DummyType.rt <= RangeType.Forward) + { + static assert(propagatesRangeType!(typeof(s), typeof(f))); + } + + assert(s.equal([1, 3, 6, 10, 15, 21, 28, 36, 45, 55])); + } +} + // splitter /** Lazily splits a range using an element as a separator. This can be used with diff --git a/std/algorithm/package.d b/std/algorithm/package.d index 5be46eb480d..ac53d97e4e7 100644 --- a/std/algorithm/package.d +++ b/std/algorithm/package.d @@ -70,6 +70,7 @@ $(TR $(TDNW Iteration) $(SUBREF iteration, cacheBidirectional) $(SUBREF iteration, chunkBy) $(SUBREF iteration, cumulativeFold) + $(SUBREF iteration, cumulativeSum) $(SUBREF iteration, each) $(SUBREF iteration, filter) $(SUBREF iteration, filterBidirectional)