diff --git a/source/mir/random/discrete.d b/source/mir/random/discrete.d index 9259731a..a448f8fe 100644 --- a/source/mir/random/discrete.d +++ b/source/mir/random/discrete.d @@ -37,7 +37,7 @@ unittest // sample from the discrete distribution auto obs = new uint[cdPoints.length]; - foreach (i; 0..1000) + foreach (i; 0..10_000) obs[ds()]++; } @@ -81,34 +81,57 @@ struct Discrete(T) import std.range : assumeSorted; T v = uniform!("[)", T, T)(0, cdPoints[$-1], gen); - return cdPoints.assumeSorted!"a <= b".lowerBound(v).length; + return cdPoints.length - cdPoints.assumeSorted!"a < b".upperBound(v).length; } } +// test with cumulative probs unittest { - import std.random : Random; - auto gen = Random(42); + import std.random : Mt19937; + auto gen = Mt19937(42); // 10%, 20%, 20%, 40%, 10% auto cdPoints = [0.1, 0.3, 0.5, 0.9, 1]; auto ds = discrete(cdPoints); auto obs = new uint[cdPoints.length]; - foreach (i; 0..1000) + foreach (i; 0..10_000) obs[ds(gen)]++; + + assert(obs == [1030, 1964, 1968, 4087, 951]); } +// test with cumulative count unittest { - import std.random : Random; - auto gen = Random(42); + import std.random : Mt19937; + auto gen = Mt19937(42); // 1, 2, 1 auto cdPoints = [1, 3, 4]; auto ds = discrete(cdPoints); auto obs = new uint[cdPoints.length]; - foreach (i; 0..1000) + foreach (i; 0..10_000) + obs[ds(gen)]++; + + assert(obs == [2536, 4963, 2501]); +} + +// test with zero probabilities +unittest +{ + import std.random : Mt19937; + auto gen = Mt19937(42); + + // 0, 1, 2, 0, 1 + auto cdPoints = [0, 1, 3, 3, 4]; + auto ds = discrete(cdPoints); + + auto obs = new uint[cdPoints.length]; + foreach (i; 0..10_000) obs[ds(gen)]++; + + assert(obs == [0, 2536, 4963, 0, 2501]); }