diff --git a/examples/forward_simulation.rs b/examples/forward_simulation.rs index 85eb931d..499350da 100644 --- a/examples/forward_simulation.rs +++ b/examples/forward_simulation.rs @@ -1,36 +1,25 @@ -//! Examples of tree sequence recording with neutral -//! Wright-Fisher models. -//! -//! This module provides a means of generating testing -//! code and benchmarking utilities. However, some of -//! the concepts here that are *not* public may be useful -//! to others. Feel free to copy them! use bitflags::bitflags; use clap::{value_t, value_t_or_exit, App, Arg}; -use forrustts::simplify_from_edge_buffer; -use forrustts::simplify_tables; -use forrustts::simplify_tables_without_state; -use forrustts::EdgeBuffer; use forrustts::ForrusttsError; use forrustts::IdType; use forrustts::Position; -use forrustts::SamplesInfo; -use forrustts::Segment; -use forrustts::SimplificationBuffers; -use forrustts::SimplificationFlags; -use forrustts::SimplificationOutput; -use forrustts::TableCollection; use forrustts::Time; use rand::rngs::StdRng; use rand::Rng; use rand::SeedableRng; -use rand_distr::{Geometric, Uniform}; +use rand_distr::{Exp, Uniform}; // Some of the material below seems like a candidate for a public API, // but we need to decide here if this package should provide that. // If so, then many of these types should not be here, as they have nothing // to do with Wright-Fisher itself, and are instead more general. +// Even though Position is an integer, we will use +// an exponential distribution to get the distance to +// the next crossover position. The reason for this is +// that rand_distr::Geometric has really poor performance. +type BreakpointFunction = Option>; + #[derive(Copy, Clone)] struct Parent { index: usize, @@ -38,40 +27,20 @@ struct Parent { node1: IdType, } -impl Parent { - pub const fn new(index: usize, node0: IdType, node1: IdType) -> Parent { - Parent { - index, - node0, - node1, - } - } -} - struct Birth { index: usize, parent0: Parent, parent1: Parent, } -impl Birth { - pub const fn new(index: usize, parent0: Parent, parent1: Parent) -> Birth { - Birth { - index, - parent0, - parent1, - } - } -} - type VecParent = Vec; type VecBirth = Vec; struct PopulationState { pub parents: VecParent, pub births: VecBirth, - pub edge_buffer: EdgeBuffer, - pub tables: TableCollection, + pub edge_buffer: forrustts::EdgeBuffer, + pub tables: forrustts::TableCollection, } impl PopulationState { @@ -79,8 +48,8 @@ impl PopulationState { PopulationState { parents: vec![], births: vec![], - edge_buffer: EdgeBuffer::new(), - tables: TableCollection::new(genome_length).unwrap(), + edge_buffer: forrustts::EdgeBuffer::new(), + tables: forrustts::TableCollection::new(genome_length).unwrap(), } } } @@ -94,7 +63,11 @@ fn deaths_and_parents(psurvival: f64, rng: &mut StdRng, pop: &mut PopulationStat Some(std::cmp::Ordering::Greater) => { let parent0 = pop.parents[rng.sample(random_parents)]; let parent1 = pop.parents[rng.sample(random_parents)]; - pop.births.push(Birth::new(i, parent0, parent1)); + pop.births.push(Birth { + index: i, + parent0, + parent1, + }); } Some(_) => (), None => (), @@ -102,46 +75,103 @@ fn deaths_and_parents(psurvival: f64, rng: &mut StdRng, pop: &mut PopulationStat } } -/// Decide which node to pass on from a parent. -fn mendel(rng: &mut StdRng, n0: IdType, n1: IdType) -> (IdType, IdType) { +fn mendel(pnodes: &mut (tskit::tsk_id_t, tskit::tsk_id_t), rng: &mut StdRng) { let x: f64 = rng.gen(); match x.partial_cmp(&0.5) { - Some(std::cmp::Ordering::Less) => (n1, n0), - Some(_) => (n0, n1), + Some(std::cmp::Ordering::Less) => { + std::mem::swap(&mut pnodes.0, &mut pnodes.1); + } + Some(_) => (), None => panic!("Unexpected None"), } } +fn crossover_and_record_edges( + parent: Parent, + child: IdType, + breakpoint: BreakpointFunction, + recorder: &impl Fn( + IdType, + IdType, + (Position, Position), + &mut forrustts::TableCollection, + &mut forrustts::EdgeBuffer, + ), + rng: &mut StdRng, + tables: &mut forrustts::TableCollection, + edge_buffer: &mut forrustts::EdgeBuffer, +) { + let mut pnodes = (parent.node0, parent.node1); + mendel(&mut pnodes, rng); + let mut p0 = parent.node0; + let mut p1 = parent.node1; + + if let Some(exp) = breakpoint { + let mut current_pos: Position = 0; + loop { + // TODO: gotta justify the next line... + let next_length = (rng.sample(exp) as Position) + 1; + assert!(next_length > 0); + if current_pos + next_length < tables.genome_length() { + recorder( + p0, + child, + (current_pos, current_pos + next_length), + tables, + edge_buffer, + ); + current_pos += next_length; + std::mem::swap(&mut p0, &mut p1); + } else { + recorder( + p0, + child, + (current_pos, tables.genome_length()), + tables, + edge_buffer, + ); + + break; + } + } + } else { + recorder(p0, child, (0, tables.genome_length()), tables, edge_buffer); + } +} + fn generate_births( - littler: f64, + breakpoint: BreakpointFunction, birth_time: Time, rng: &mut StdRng, - breakpoints: &mut Vec, pop: &mut PopulationState, - recorder: impl Fn((IdType, IdType), IdType, &[Position], &mut TableCollection, &mut EdgeBuffer), + recorder: &impl Fn( + IdType, + IdType, + (Position, Position), + &mut forrustts::TableCollection, + &mut forrustts::EdgeBuffer, + ), ) { for b in &pop.births { - let parent0_nodes = mendel(rng, b.parent0.node0, b.parent0.node1); - let parent1_nodes = mendel(rng, b.parent1.node0, b.parent1.node1); - // Record 2 new nodes let new_node_0: IdType = pop.tables.add_node(birth_time, 0).unwrap(); let new_node_1: IdType = pop.tables.add_node(birth_time, 0).unwrap(); - recombination_breakpoints(littler, pop.tables.genome_length(), rng, breakpoints); - recorder( - parent0_nodes, + crossover_and_record_edges( + b.parent0, new_node_0, - breakpoints, + breakpoint, + recorder, + rng, &mut pop.tables, &mut pop.edge_buffer, ); - - recombination_breakpoints(littler, pop.tables.genome_length(), rng, breakpoints); - recorder( - parent1_nodes, + crossover_and_record_edges( + b.parent1, new_node_1, - breakpoints, + breakpoint, + recorder, + rng, &mut pop.tables, &mut pop.edge_buffer, ); @@ -153,110 +183,28 @@ fn generate_births( } fn buffer_edges( - parents: (IdType, IdType), + parent: IdType, child: IdType, - breakpoints: &[Position], - tables: &mut TableCollection, - buffer: &mut EdgeBuffer, + span: (Position, Position), + _: &mut forrustts::TableCollection, + buffer: &mut forrustts::EdgeBuffer, ) { - if breakpoints.is_empty() { - buffer - .extend(parents.0, Segment::new(0, tables.genome_length(), child)) - .unwrap(); - return; - } - - // If we don't have a breakpoint at 0, add an edge - if breakpoints[0] != 0 { - buffer - .extend(parents.0, Segment::new(0, breakpoints[0], child)) - .unwrap(); - } - - for i in 1..breakpoints.len() { - let a = breakpoints[i - 1]; - let b = if i < (breakpoints.len() - 1) { - breakpoints[i] - } else { - tables.genome_length() - }; - if i % 2 == 0 { - buffer.extend(parents.0, Segment::new(a, b, child)).unwrap(); - } else { - buffer.extend(parents.1, Segment::new(a, b, child)).unwrap(); - } - } + buffer + .extend(parent, forrustts::Segment::new(span.0, span.1, child)) + .unwrap(); } fn record_edges( - parents: (IdType, IdType), + parent: IdType, child: IdType, - breakpoints: &[Position], - tables: &mut TableCollection, - _: &mut EdgeBuffer, + span: (Position, Position), + tables: &mut forrustts::TableCollection, + _: &mut forrustts::EdgeBuffer, ) { - if breakpoints.is_empty() { - tables - .add_edge(0, tables.genome_length(), parents.0, child) - .unwrap(); - return; - } - - // If we don't have a breakpoint at 0, add an edge - if breakpoints[0] != 0 { - tables - .add_edge(0, breakpoints[0], parents.0, child) - .unwrap(); - } - - for i in 1..breakpoints.len() { - let a = breakpoints[i - 1]; - let b = if i < (breakpoints.len() - 1) { - breakpoints[i] - } else { - tables.genome_length() - }; - if i % 2 == 0 { - tables.add_edge(a, b, parents.0, child).unwrap(); - } else { - tables.add_edge(a, b, parents.1, child).unwrap(); - } - } -} - -fn recombination_breakpoints( - littler: f64, - maxlen: Position, - rng: &mut StdRng, - breakpoints: &mut Vec, -) { - breakpoints.clear(); - match littler.partial_cmp(&0.0) { - Some(std::cmp::Ordering::Greater) => { - let geom = match Geometric::new(littler / maxlen as f64) { - Ok(g) => g, - Err(e) => panic!("{}", e), - }; - let mut current_pos: Position = 0; - loop { - let next_length = rng.sample(geom) as Position; - if current_pos + next_length < maxlen { - current_pos += next_length; - breakpoints.push(current_pos); - } else { - break; - } - } - } - Some(_) => {} - None => (), - } - if !breakpoints.is_empty() { - breakpoints.push(Position::MAX); - } + tables.add_edge(span.0, span.1, parent, child).unwrap(); } -fn fill_samples(parents: &[Parent], samples: &mut SamplesInfo) { +fn fill_samples(parents: &[Parent], samples: &mut forrustts::SamplesInfo) { samples.samples.clear(); for p in parents { samples.samples.push(p.node0); @@ -266,17 +214,17 @@ fn fill_samples(parents: &[Parent], samples: &mut SamplesInfo) { fn sort_and_simplify( flags: SimulationFlags, - simplification_flags: SimplificationFlags, - samples: &SamplesInfo, - state: &mut SimplificationBuffers, + simplification_flags: forrustts::SimplificationFlags, + samples: &forrustts::SamplesInfo, + state: &mut forrustts::SimplificationBuffers, pop: &mut PopulationState, - output: &mut SimplificationOutput, + output: &mut forrustts::SimplificationOutput, ) { if !flags.contains(SimulationFlags::BUFFER_EDGES) { pop.tables .sort_tables(forrustts::TableSortingFlags::empty()); if flags.contains(SimulationFlags::USE_STATE) { - simplify_tables( + forrustts::simplify_tables( samples, simplification_flags, state, @@ -285,11 +233,16 @@ fn sort_and_simplify( ) .unwrap(); } else { - simplify_tables_without_state(samples, simplification_flags, &mut pop.tables, output) - .unwrap(); + forrustts::simplify_tables_without_state( + samples, + simplification_flags, + &mut pop.tables, + output, + ) + .unwrap(); } } else { - simplify_from_edge_buffer( + forrustts::simplify_from_edge_buffer( samples, simplification_flags, state, @@ -303,11 +256,11 @@ fn sort_and_simplify( fn simplify_and_remap_nodes( flags: SimulationFlags, - simplification_flags: SimplificationFlags, - samples: &mut SamplesInfo, - state: &mut SimplificationBuffers, + simplification_flags: forrustts::SimplificationFlags, + samples: &mut forrustts::SamplesInfo, + state: &mut forrustts::SimplificationBuffers, pop: &mut PopulationState, - output: &mut SimplificationOutput, + output: &mut forrustts::SimplificationOutput, ) { fill_samples(&pop.parents, samples); sort_and_simplify(flags, simplification_flags, samples, state, pop, output); @@ -333,84 +286,56 @@ fn validate_simplification_interval(x: Time) -> Time { x } -// NOTE: this function is a copy of the simulation -// found in fwdpp/examples/edge_buffering.cc - -/// Parameters of a population to be evolved by -/// [``neutral_wf``]. pub struct PopulationParams { - /// Diploid population size pub size: u32, - /// Genome length. Easiest to think "base pairs", - /// but more abstract concepts are valid. pub genome_length: Position, - /// Mean number of crossovers (per mating). - pub littler: f64, - /// Survival probability. Must be 0 <= p < 1. + pub breakpoint: BreakpointFunction, pub psurvival: f64, } impl PopulationParams { - /// Create a new instance - /// - /// # Limitations - /// - /// As of 0.1.0, input values are not validated. - pub fn new(size: u32, genome_length: Position, littler: f64, psurvival: f64) -> Self { + pub fn new(size: u32, genome_length: Position, xovers: f64, psurvival: f64) -> Self { PopulationParams { size, genome_length, - littler, + breakpoint: match xovers.partial_cmp(&0.0) { + Some(std::cmp::Ordering::Greater) => { + Some(Exp::new(xovers / genome_length as f64).unwrap()) + } + Some(_) => None, + None => None, + }, psurvival, } } } bitflags! { - /// Bitwise flag tweaking the behavior of the - /// simplification algorithm. #[derive(Default)] pub struct SimulationFlags: u32 { - /// If set, and [``BUFFER_EDGES``] is not set, - /// then simplification will use a reusable set - /// of buffers for each call. Otherwise, - /// these buffers will be allocated each time - /// simplification happens. + // If set, and BUFFER_EDGES is not set, + // then simplification will use a reusable set + // of buffers for each call. Otherwise, + // these buffers will be allocated each time + // simplification happens. const USE_STATE = 1 << 0; - /// If set, edge buffering will be used. - /// If not set, then the standard "record - /// and sort" method will be used. + // If set, edge buffering will be used. + // If not set, then the standard "record + // and sort" method will be used. const BUFFER_EDGES = 1 << 1; } } -/// Parameters of a simulation to be executed -/// by [``neutral_wf``]. pub struct SimulationParams { - /// How often to apply the simplification algorithm. - /// If ``None``, then simplification never happens. - /// If the value is ``Some(Time)``, then simplification - /// will occur after that many time steps. pub simplification_interval: Option