diff --git a/src/include/rng.hxx b/src/include/rng.hxx index c8381fa72..9b8c7236e 100644 --- a/src/include/rng.hxx +++ b/src/include/rng.hxx @@ -15,8 +15,8 @@ namespace rng namespace detail { using Generator = std::default_random_engine; -using ProcessGeneratorPtr = std::shared_ptr; -ProcessGeneratorPtr process_generator; +using GeneratorPtr = std::shared_ptr; +GeneratorPtr process_generator; int get_process_seed() { @@ -25,7 +25,7 @@ int get_process_seed() return 10000 * (rank + 1); } -ProcessGeneratorPtr get_process_generator() +GeneratorPtr get_process_generator() { if (!process_generator) { process_generator = std::make_shared(get_process_seed()); @@ -33,6 +33,12 @@ ProcessGeneratorPtr get_process_generator() return process_generator; } +GeneratorPtr get_constant_generator(int seed) +{ + // FIXME this doesn't need to be a shared pointer + return std::make_shared(seed); +} + } // namespace detail // ====================================================================== @@ -41,6 +47,10 @@ ProcessGeneratorPtr get_process_generator() template struct Uniform { + Uniform(Real min, Real max, int seed) + : dist(min, max), gen(detail::get_constant_generator(seed)) + {} + Uniform(Real min, Real max) : dist(min, max), gen(detail::get_process_generator()) {} @@ -50,7 +60,7 @@ struct Uniform Real get() { return dist(*gen); } private: - detail::ProcessGeneratorPtr gen; + detail::GeneratorPtr gen; std::uniform_real_distribution dist; }; @@ -60,9 +70,14 @@ private: template struct Normal { + Normal(Real mean, Real stdev, int seed) + : dist(mean, stdev), gen(detail::get_constant_generator(seed)) + {} + Normal(Real mean, Real stdev) : dist(mean, stdev), gen(detail::get_process_generator()) {} + Normal() : Normal(0, 1) {} Real get() { return dist(*gen); } @@ -74,7 +89,7 @@ struct Normal } private: - detail::ProcessGeneratorPtr gen; + detail::GeneratorPtr gen; std::normal_distribution dist; };