From 4a03dd9bda6eaf5d1b03116bf32bfeef42e3f2af Mon Sep 17 00:00:00 2001 From: Sebastian Wilzbach Date: Sun, 26 Jun 2016 15:25:01 +0200 Subject: [PATCH 1/2] mir.random.discrete - fix for zero probilities --- source/mir/random/discrete.d | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/source/mir/random/discrete.d b/source/mir/random/discrete.d index 9259731a..c6206e9e 100644 --- a/source/mir/random/discrete.d +++ b/source/mir/random/discrete.d @@ -37,7 +37,7 @@ unittest // sample from the discrete distribution auto obs = new uint[cdPoints.length]; - foreach (i; 0..1000) + foreach (i; 0..10_000) obs[ds()]++; } @@ -81,10 +81,11 @@ struct Discrete(T) import std.range : assumeSorted; T v = uniform!("[)", T, T)(0, cdPoints[$-1], gen); - return cdPoints.assumeSorted!"a <= b".lowerBound(v).length; + return cdPoints.length - cdPoints.assumeSorted!"a < b".upperBound(v).length; } } +// test with cumulative probs unittest { import std.random : Random; @@ -95,10 +96,13 @@ unittest auto ds = discrete(cdPoints); auto obs = new uint[cdPoints.length]; - foreach (i; 0..1000) + foreach (i; 0..10_000) obs[ds(gen)]++; + + assert(obs == [1030, 1964, 1968, 4087, 951]); } +// test with cumulative count unittest { import std.random : Random; @@ -109,6 +113,25 @@ unittest auto ds = discrete(cdPoints); auto obs = new uint[cdPoints.length]; - foreach (i; 0..1000) + foreach (i; 0..10_000) + obs[ds(gen)]++; + + assert(obs == [2536, 4963, 2501]); +} + +// test with zero probabilities +unittest +{ + import std.random : Random; + auto gen = Random(42); + + // 0, 1, 2, 0, 1 + auto cdPoints = [0, 1, 3, 3, 4]; + auto ds = discrete(cdPoints); + + auto obs = new uint[cdPoints.length]; + foreach (i; 0..10_000) obs[ds(gen)]++; + + assert(obs == [0, 2536, 4963, 0, 2501]); } From ad3bde417b227dc04ecc138686849ff6d59c7d7c Mon Sep 17 00:00:00 2001 From: Sebastian Wilzbach Date: Sun, 26 Jun 2016 17:20:35 +0200 Subject: [PATCH 2/2] replace Random alias with Mt19937 --- source/mir/random/discrete.d | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/source/mir/random/discrete.d b/source/mir/random/discrete.d index c6206e9e..a448f8fe 100644 --- a/source/mir/random/discrete.d +++ b/source/mir/random/discrete.d @@ -88,8 +88,8 @@ struct Discrete(T) // test with cumulative probs unittest { - import std.random : Random; - auto gen = Random(42); + import std.random : Mt19937; + auto gen = Mt19937(42); // 10%, 20%, 20%, 40%, 10% auto cdPoints = [0.1, 0.3, 0.5, 0.9, 1]; @@ -105,8 +105,8 @@ unittest // test with cumulative count unittest { - import std.random : Random; - auto gen = Random(42); + import std.random : Mt19937; + auto gen = Mt19937(42); // 1, 2, 1 auto cdPoints = [1, 3, 4]; @@ -122,8 +122,8 @@ unittest // test with zero probabilities unittest { - import std.random : Random; - auto gen = Random(42); + import std.random : Mt19937; + auto gen = Mt19937(42); // 0, 1, 2, 0, 1 auto cdPoints = [0, 1, 3, 3, 4];