diff --git a/src/compression.rs b/src/compression.rs index 00bb2245..5e1c8af9 100644 --- a/src/compression.rs +++ b/src/compression.rs @@ -1312,6 +1312,15 @@ fn get_ivars_expansions(original_pattern: &Pattern, arg_of_loc: &FxHashMap, arg_of_loc_2: &FxHashMap) -> Vec { + let locs: Vec = original_pattern.match_locations.iter() + .filter(|loc:&&Idx| arg_of_loc_1[loc].shifted_id == + arg_of_loc_2[loc].shifted_id + && !invalid_metavar_location(shared, arg_of_loc_1[loc].shifted_id) + ).cloned().collect(); + locs +} + /// A finished abstraction #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/src/expand_variable.rs b/src/expand_variable.rs index 0c0c45ed..1593b52b 100644 --- a/src/expand_variable.rs +++ b/src/expand_variable.rs @@ -7,7 +7,7 @@ pub fn add_variable_at(p: &mut Pattern, at_loc: usize, var_id: i32) { } } -pub fn remove_variable_at(p: &mut Pattern, var_id: usize) -> Vec { +pub fn remove_variable_at(p: &mut Pattern, var_id: usize, expands_to: &mut ExpandsTo) -> Vec { let mut zids = Vec::new(); // remove the variable from the arg choices p.arg_choices.retain(|x: &LabelledZId| { @@ -24,6 +24,12 @@ pub fn remove_variable_at(p: &mut Pattern, var_id: usize) -> Vec { x.ivar -= 1; // decrement the ivar index for all variables after the one we're removing } }); + if let ExpandsTo::IVar(i) = expands_to { + assert!(*i != var_id as i32, "ExpandsTo::IVar should not be the variable we're removing"); + if *i > var_id as i32 { + *i -= 1; + } + } // remove the variable from the first_zid_of_ivar p.first_zid_of_ivar.remove(var_id); zids @@ -46,7 +52,25 @@ pub fn check_consistency(shared: &SharedData, p: &Pattern) { // check that the ivar is within bounds let arg_of_loc = &shared.arg_of_zid_node[labeled.zid]; for loc in p.match_locations.iter() { - assert!(arg_of_loc.contains_key(loc), "Variable {} at location {} is not consistent with shared data", labeled.ivar, loc); + assert!(arg_of_loc.contains_key(loc), "Variable id={}, zid={} at location {} is not consistent with shared data", labeled.ivar, labeled.zid, loc); + } + } + for (ivar, zid) in p.first_zid_of_ivar.iter().enumerate() { + for labeled in p.arg_choices.iter() { + if labeled.ivar == ivar { + // check that they expand to the same thing + let arg_of_loc_1 = &shared.arg_of_zid_node[labeled.zid]; + let arg_of_loc_2 = &shared.arg_of_zid_node[*zid]; + // println!("Checking consistency for variable id={} (zid={} vs zid={})", ivar, zid, labeled.zid); + for loc in p.match_locations.iter() { + assert!(arg_of_loc_1.contains_key(loc) && arg_of_loc_2.contains_key(loc), + "Variable id={} at location {} is not consistent with shared data: {:?} vs {:?}", ivar, loc, arg_of_loc_1.get(loc), arg_of_loc_2.get(loc)); + assert_eq!(arg_of_loc_1[loc].shifted_id, arg_of_loc_2[loc].shifted_id, + "Variable id={} at location {} has different shifted ids: {} vs {}", ivar, loc, arg_of_loc_1[loc].shifted_id, arg_of_loc_2[loc].shifted_id); + assert_eq!(arg_of_loc_1[loc].expands_to, arg_of_loc_2[loc].expands_to, + "Variable id={} at location {} expands to different things: {} vs {}", ivar, loc, arg_of_loc_1[loc].expands_to, arg_of_loc_2[loc].expands_to); + } + } } } // for i in 0..num_vars { @@ -71,11 +95,13 @@ pub fn perform_expansion_variable( // if tracked { found_tracked = true; } // if shared.cfg.follow_prune && !tracked { return None; } let mut pattern = pattern; + let mut expands_to = expands_to; - // check_consistency(shared, original_pattern); + // check_consistency(shared, &pattern); + // println!("expands_to: {:?}", expands_to); + // println!("pattern: {:?}", pattern); // update the body utility - let body_utility = pattern.body_utility + compute_body_utility_change(shared, &expands_to); // assert!(shared.cfg.no_opt_upper_bound || !holes_after_pop.is_empty() || !original_pattern.arg_choices.is_empty() || expands_to.has_holes() || expands_to.is_ivar(), // "unexpected arity 0 invention: upper bounds + priming with arity 0 inventions should have prevented this"); @@ -85,12 +111,18 @@ pub fn perform_expansion_variable( // build our new pattern with all the variables we've just defined. Copy in the argchoices and prefixes // from the old pattern. // new_pattern.match_locations = locs; - pattern.body_utility = body_utility; // println!("targeting ivar: {}", variable_ivar); // println!("expands_to: {:?}", expands_to); // println!("original: {:?}", new_pattern); - let variable_zids: Vec = remove_variable_at(&mut pattern, variable_ivar); + let variable_zids: Vec = remove_variable_at(&mut pattern, variable_ivar, &mut expands_to); + + let body_utility = pattern.body_utility + compute_body_utility_change(shared, &expands_to) * variable_zids.len() as i32; + pattern.body_utility = body_utility; + + // println!("pattern after remove: {:?}", pattern); + // println!("expands_to after remove: {:?}", expands_to); + // println!("variable_zids: {:?}", variable_zids); // println!("after removing variable: {:?}", new_pattern); // println!("variable_zids: {:?}", variable_zids); @@ -115,10 +147,11 @@ pub fn perform_expansion_variable( add_variable_at(&mut pattern, variable_zid, i); } } + // println!("pattern after add: {:?}", pattern); // println!("after adding variable: {:?}", new_pattern); - // check_consistency(shared, &new_pattern); + // check_consistency(shared, &pattern); Some (pattern) diff --git a/src/smc.rs b/src/smc.rs index 2e3bc932..c9d220c7 100644 --- a/src/smc.rs +++ b/src/smc.rs @@ -4,6 +4,92 @@ use rand::SeedableRng; use rustc_hash::{FxHashMap}; use std::sync::Arc; +fn sample_new_ivar( + original_pattern: &Pattern, + shared: &SharedData, + variable_ivar: usize, + match_loc: &usize, + rng: &mut impl rand::Rng, +) -> Option { + let num_vars = get_num_variables(original_pattern); + if num_vars <= 1 { + return None; // no other variable to expand to + } + let mut new_ivar = rng.gen_range(0..num_vars - 1); + if new_ivar >= variable_ivar { + new_ivar += 1; // skip the variable we are expanding + } + let zid_original = get_zid_for_ivar(original_pattern, variable_ivar); + let zid_new = get_zid_for_ivar(original_pattern, new_ivar); + if shared.arg_of_zid_node[zid_original][match_loc].shifted_id == shared.arg_of_zid_node[zid_new][match_loc].shifted_id { + return Some(new_ivar); + } + return None; +} + +fn sample_variable_reuse( + pattern: &Pattern, + shared: &SharedData, + variable_ivar: usize, + match_location: usize, + rng: &mut impl rand::Rng, +) -> Option<(Pattern, ExpandsTo)> { + if let Some(new_ivar) = sample_new_ivar(pattern, shared, variable_ivar, &match_location, rng) { + let zid_original = get_zid_for_ivar(pattern, variable_ivar); + let zid_new = get_zid_for_ivar(pattern, new_ivar); + let locs = compatible_locations( + shared, + pattern, + &shared.arg_of_zid_node[zid_original], + &shared.arg_of_zid_node[zid_new], + ); + if !locs.is_empty() { + let mut pattern = pattern.clone(); + pattern.match_locations = locs; + let expands_to = ExpandsTo::IVar(new_ivar as i32); + return Some((pattern, expands_to)); + } + } + None +} + +fn sample_syntactic_expansion( + original_pattern: &Pattern, + arg_of_loc: &FxHashMap, + match_location: usize, +) -> (Pattern, ExpandsTo) { + let mut pattern = original_pattern.clone(); + let expands_to = arg_of_loc[&match_location].expands_to.clone(); + pattern.match_locations.retain( + |loc| arg_of_loc[&loc].expands_to == expands_to + ); + return (pattern, expands_to); +} + +fn sample_expands_to( + original_pattern: &Pattern, + shared: &SharedData, + arg_of_loc: &FxHashMap, + match_location: usize, + variable_ivar: usize, + rng: &mut impl rand::Rng, +) -> (Pattern, ExpandsTo) { + if let Some(out) = sample_variable_reuse( + original_pattern, + shared, + variable_ivar, + match_location, + rng, + ) { + return out; + } + return sample_syntactic_expansion( + original_pattern, + arg_of_loc, + match_location, + ); +} + pub fn smc_expand( original_pattern: &Pattern, shared: &SharedData, @@ -26,14 +112,11 @@ pub fn smc_expand( // hole_idx, // );e let variable_zid = get_zid_for_ivar(original_pattern, variable_ivar); + // println!("Original variable zid for ivar={}: {}", variable_ivar, variable_zid); // println!("Variable ZID: {}", variable_zid); let arg_of_loc = &shared.arg_of_zid_node[variable_zid]; // println!("Argument of location: {:?}", arg_of_loc); - let expands_to = arg_of_loc[&match_location].expands_to.clone(); - let mut pattern = original_pattern.clone(); - pattern.match_locations.retain( - |loc| arg_of_loc[&loc].expands_to == expands_to - ); + let (pattern, expands_to) = sample_expands_to(original_pattern, shared, arg_of_loc, match_location, variable_ivar, rng); perform_expansion_variable( pattern, &shared,