Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 311 additions & 0 deletions std/algorithm/iteration.d
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ $(T2 cumulativeFold,
$(D cumulativeFold!((a, b) => a + b)([1, 2, 3, 4])) returns a
lazily-evaluated range containing the successive reduced values `1`,
`3`, `6`, `10`.)
$(T2 cumulativeSum,
Same as $(D cumulativeFold), but specialized for accurate summation.)
$(T2 each,
$(D each!writeln([1, 2, 3])) eagerly prints the numbers $(D 1), $(D 2)
and $(D 3) on their own lines.)
Expand Down Expand Up @@ -3483,6 +3485,315 @@ The number of seeds must be correspondingly increased.
}
}

/++
Performs a summation of the given $(REF_ALTTEXT input range, isInputRange,
std, range, primitives) `r`, and provides the intermediate results of the
summation as an $(REF_ALTTEXT input range, isInputRange, std, range,
primitives). `cumulativeSum` is conceptually equivalent to
`cumulativeFold!((a, b) => a + b)`, but for floating point summations the
$(HTTP en.wikipedia.org/wiki/Kahan_summation, Kahan summation) algorithm is
used to reduce accuracy loss from cancellation errors.

When called without a seed, the seed type is deduced from the $(REF_ALTTEXT
element type, ElementType, std,range,primitives) of `r` and the seed value is
0. If the $(REF_ALTTEXT element type, ElementType, std, range, primitives) of
`r` is a $(REF_ALTTEXT floating point type, isFloatingPoint, std, traits), then
the seed type will be deduced to be the most precise type available from either
`double` or `real`.

Params:
r = any $(REF_ALTTEXT input range, isInputRange, std, range, primitives)
s = a seed value that gives the initial value of the summation

Returns:
An $(REF_ALTTEXT input range, isInputRange, std, range, primitives)
containing the intermediate results of the summation of `r`.

See_Also:
$(HTTPS en.wikipedia.org/wiki/Prefix_sum, Prefix Sum)

$(LREF cumulativeFold) provides the intermediate results of generic
reduction operations on ranges.

$(LREF sum) performs a summation of a range without providing intermediate
results.
+/
auto cumulativeSum(Range)(Range r)
if (isInputRange!Range && __traits(compiles, r.front + r.front))
{
static if (isFloatingPoint!(ElementType!Range))
{
// Most precise seed type available from either double or real.
alias Seed = typeof(0.0 + r.front);
}
else
{
// Deduce seed type from the result of a single addition.
alias Seed = typeof(r.front + r.front);
}

return r.cumulativeSum(Seed(0));
}

/// Ditto
auto cumulativeSum(Range, Seed)(Range r, Seed s)
if (isInputRange!Range && __traits(compiles, r.front + r.front))
{
static if (isFloatingPoint!Seed)
{
static struct Result
{
this(Range r, Seed s)
{
_r = r;
if (_r.empty) return;
_s = s;
sumFront;
}

@property
auto front()
in
{
assert(!empty,
"Attempting to fetch the front of an empty cumulativeSum");
}
body
{
return _s;
}

void popFront()
in
{
assert(!empty,
"Attempting to popFront an empty cumulativeSum");
}
body
{
_r.popFront;
if (_r.empty) return;
sumFront;
}

static if (isInfinite!Range)
{
enum empty = false;
}
else
{
@property
bool empty()
{
return _r.empty;
}
}

static if (isForwardRange!Range)
{
@property
auto save()
{
auto result = this;
result._r = _r.save;
return result;
}
}

static if (hasLength!Range)
{
@property
size_t length()
{
return _r.length;
}
}
private:
Range _r;
Seed _s;
Seed _c = 0;

void sumFront()
{
// One iteration of Kahan summation.
immutable y = _r.front - _c;
immutable t = _s + y;
_c = (t - _s) - y;
_s = t;
}
}

return Result(r, s);
}
else
{
// Default to naive summation for integral values.
return r.cumulativeFold!((a, b) => a + b)(s);
}
}

///
@safe pure nothrow
unittest
{
import std.algorithm.comparison : equal;
import std.range : iota, repeat;

// Partial sum of integral values:
assert(cumulativeSum([1, 2, 3, 4, 5]).equal([1, 3, 6, 10, 15]));

// Using ranges and UFCS:
assert(iota(1, 6).cumulativeSum.equal([1, 3, 6, 10, 15]));

// With seed value:
assert(iota(1, 6).cumulativeSum(-15).equal([-14, -12, -9, -5, 0]));


// Partial sum of floating point values:
assert(cumulativeSum([1.0, 2.0, 3.0, 4.0, 5.0])
.equal([1.0, 3.0, 6.0, 10.0, 15.0]));

// With seed value:
assert(cumulativeSum([1.0, 2.0, 3.0, 4.0, 5.0], -15.0)
.equal([-14.0, -12.0, -9.0, -5.0, 0.0]));


// Partial sum with integral promotion:
assert(cumulativeSum([false, true, true, false, true])
.equal([0, 1, 2, 2, 3]));

// Similarly, a seed can be used to force floating point summation:
assert(cumulativeSum([false, true, true, false, true], 0.0)
.equal([0.0, 1.0, 2.0, 2.0, 3.0]));


// The result may overflow:
assert(uint.max.repeat(3).cumulativeSum
.equal([4294967295U, 4294967294U, 4294967293U]));

// But a seed can be used to change the sumation primitive:
assert(uint.max.repeat(3).cumulativeSum(ulong.init)
.equal([4294967295UL, 8589934590UL, 12884901885UL]));
}

/++
`cumulativeSum` uses Kahan summation to give more accurate results than
naive summation for ranges of floating point values.
+/
@safe pure nothrow
unittest
{
import std.math : approxEqual;

// Despite summing 'large' and 'small' numbers the loss of significance is
// a non-issue.
assert(cumulativeSum([10000, 3.14159, 2.71828, 1.41421, 1.61803, -10000])
.approxEqual([10000, 10003.1, 10005.9, 10007.3, 10008.9, 8.89211]));

// Another example with a 'large' seed value.
assert(cumulativeSum([6.28318, 1.73205, 3.33333, 2.23606, -10000], 10000.0)
.approxEqual([10006.3, 10008.0, 10011.3, 10013.6, 13.5846]));

// A more extreme example.
assert(cumulativeSum([71850, 1.594e-11, 7.91182e-11, 2.36169e-11, -71850])
.approxEqual([71850, 71850, 71850, 71850, 1.18675e-10]));
}

@safe pure nothrow
unittest
{
import std.range.primitives : ElementType;
import std.algorithm.comparison : equal;

// Integral types:

static assert(is(ElementType!(typeof(cumulativeSum([cast(byte)1]))) == int));
static assert(is(ElementType!(typeof(cumulativeSum([cast(ubyte)1]))) == int));
static assert(is(ElementType!(typeof(cumulativeSum([1, 2, 3, 4]))) == int));
static assert(is(ElementType!(typeof(cumulativeSum([1U, 2U, 3U, 4U]))) == uint));
static assert(is(ElementType!(typeof(cumulativeSum([1L, 2L, 3L, 4L]))) == long));
static assert(is(ElementType!(typeof(cumulativeSum([1UL, 2UL, 3UL, 4UL]))) == ulong));

int[] empty;
assert(cumulativeSum(empty).empty);
assert(cumulativeSum([42]).equal([42]));
assert(cumulativeSum([42, 43]).equal([42, 85]));
assert(cumulativeSum([42, 43, 44]).equal([42, 85, 129]));
assert(cumulativeSum([42, 43, 44, 45]).equal([42, 85, 129, 174]));
}

@safe pure nothrow
unittest
{
import std.range.primitives : ElementType;
import std.algorithm.comparison : equal;

// Floating point types:

static assert(is(ElementType!(typeof(cumulativeSum([1F, 2F, 3F, 4F]))) == double));
static assert(is(ElementType!(typeof(cumulativeSum([1.0, 2.0, 3.0, 4.0]))) == double));
static assert(is(ElementType!(typeof(cumulativeSum([1.0L, 2.0L, 3.0L, 4.0L]))) == real));
const(float[]) a = [1F, 2F, 3F, 4F];
static assert(is(ElementType!(typeof(cumulativeSum(a))) == double));
const(float)[] b = [1F, 2F, 3F, 4F];
static assert(is(ElementType!(typeof(cumulativeSum(b))) == double));

double[] empty;
assert(cumulativeSum(empty).empty);
assert(cumulativeSum([42.0]).equal([42]));
assert(cumulativeSum([42.0, 43.0]).equal([42, 85]));
assert(cumulativeSum([42.0, 43.0, 44.0]).equal([42, 85, 129]));
assert(cumulativeSum([42.0, 43.0, 44.0, 45.5])
.equal([42, 85, 129, 174.5]));
}

@safe @nogc pure nothrow
unittest
{
import std.algorithm.comparison : equal;
import std.range : iota, repeat;

foreach (n; iota(50))
{
assert(repeat(1, n).cumulativeSum(-1.0).equal(iota(n)));
}
}

@safe pure nothrow
unittest
{
import std.algorithm.comparison : equal;
import std.internal.test.dummyrange : AllDummyRanges, propagatesLength,
propagatesRangeType, RangeType;
import std.algorithm.iteration : map, joiner;
import std.range : isForwardRange, chunks;

foreach (DummyType; AllDummyRanges)
{
DummyType d;

// Test floating point values as integral values are handled by
// cumulativeFold.
auto f = d.map!(n => cast(double)n);

static if (isForwardRange!(typeof(f)))
{
assert(f.chunks(1).map!cumulativeSum.joiner.equal(f));
}

auto s = f.cumulativeSum;

static assert(propagatesLength!(typeof(s), typeof(f)));

static if (DummyType.rt <= RangeType.Forward)
{
static assert(propagatesRangeType!(typeof(s), typeof(f)));
}

assert(s.equal([1, 3, 6, 10, 15, 21, 28, 36, 45, 55]));
}
}

// splitter
/**
Lazily splits a range using an element as a separator. This can be used with
Expand Down
1 change: 1 addition & 0 deletions std/algorithm/package.d
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ $(TR $(TDNW Iteration)
$(SUBREF iteration, cacheBidirectional)
$(SUBREF iteration, chunkBy)
$(SUBREF iteration, cumulativeFold)
$(SUBREF iteration, cumulativeSum)
$(SUBREF iteration, each)
$(SUBREF iteration, filter)
$(SUBREF iteration, filterBidirectional)
Expand Down