Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,15 @@ fn get_ivars_expansions(original_pattern: &Pattern, arg_of_loc: &FxHashMap<Idx,A
ivars_expansions
}

pub fn compatible_locations(shared: &SharedData, original_pattern: &Pattern, arg_of_loc_1: &FxHashMap<Idx,Arg>, arg_of_loc_2: &FxHashMap<Idx,Arg>) -> Vec<usize> {
let locs: Vec<Idx> = original_pattern.match_locations.iter()
.filter(|loc:&&Idx| arg_of_loc_1[loc].shifted_id ==
arg_of_loc_2[loc].shifted_id
&& !invalid_metavar_location(shared, arg_of_loc_1[loc].shifted_id)
).cloned().collect();
locs
}


/// A finished abstraction
#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down
47 changes: 40 additions & 7 deletions src/expand_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub fn add_variable_at(p: &mut Pattern, at_loc: usize, var_id: i32) {
}
}

pub fn remove_variable_at(p: &mut Pattern, var_id: usize) -> Vec<ZId> {
pub fn remove_variable_at(p: &mut Pattern, var_id: usize, expands_to: &mut ExpandsTo) -> Vec<ZId> {
let mut zids = Vec::new();
// remove the variable from the arg choices
p.arg_choices.retain(|x: &LabelledZId| {
Expand All @@ -24,6 +24,12 @@ pub fn remove_variable_at(p: &mut Pattern, var_id: usize) -> Vec<ZId> {
x.ivar -= 1; // decrement the ivar index for all variables after the one we're removing
}
});
if let ExpandsTo::IVar(i) = expands_to {
assert!(*i != var_id as i32, "ExpandsTo::IVar should not be the variable we're removing");
if *i > var_id as i32 {
*i -= 1;
}
}
// remove the variable from the first_zid_of_ivar
p.first_zid_of_ivar.remove(var_id);
zids
Expand All @@ -46,7 +52,25 @@ pub fn check_consistency(shared: &SharedData, p: &Pattern) {
// check that the ivar is within bounds
let arg_of_loc = &shared.arg_of_zid_node[labeled.zid];
for loc in p.match_locations.iter() {
assert!(arg_of_loc.contains_key(loc), "Variable {} at location {} is not consistent with shared data", labeled.ivar, loc);
assert!(arg_of_loc.contains_key(loc), "Variable id={}, zid={} at location {} is not consistent with shared data", labeled.ivar, labeled.zid, loc);
}
}
for (ivar, zid) in p.first_zid_of_ivar.iter().enumerate() {
for labeled in p.arg_choices.iter() {
if labeled.ivar == ivar {
// check that they expand to the same thing
let arg_of_loc_1 = &shared.arg_of_zid_node[labeled.zid];
let arg_of_loc_2 = &shared.arg_of_zid_node[*zid];
// println!("Checking consistency for variable id={} (zid={} vs zid={})", ivar, zid, labeled.zid);
for loc in p.match_locations.iter() {
assert!(arg_of_loc_1.contains_key(loc) && arg_of_loc_2.contains_key(loc),
"Variable id={} at location {} is not consistent with shared data: {:?} vs {:?}", ivar, loc, arg_of_loc_1.get(loc), arg_of_loc_2.get(loc));
assert_eq!(arg_of_loc_1[loc].shifted_id, arg_of_loc_2[loc].shifted_id,
"Variable id={} at location {} has different shifted ids: {} vs {}", ivar, loc, arg_of_loc_1[loc].shifted_id, arg_of_loc_2[loc].shifted_id);
assert_eq!(arg_of_loc_1[loc].expands_to, arg_of_loc_2[loc].expands_to,
"Variable id={} at location {} expands to different things: {} vs {}", ivar, loc, arg_of_loc_1[loc].expands_to, arg_of_loc_2[loc].expands_to);
}
}
}
}
// for i in 0..num_vars {
Expand All @@ -71,11 +95,13 @@ pub fn perform_expansion_variable(
// if tracked { found_tracked = true; }
// if shared.cfg.follow_prune && !tracked { return None; }
let mut pattern = pattern;
let mut expands_to = expands_to;


// check_consistency(shared, original_pattern);
// check_consistency(shared, &pattern);
// println!("expands_to: {:?}", expands_to);
// println!("pattern: {:?}", pattern);
// update the body utility
let body_utility = pattern.body_utility + compute_body_utility_change(shared, &expands_to);

// assert!(shared.cfg.no_opt_upper_bound || !holes_after_pop.is_empty() || !original_pattern.arg_choices.is_empty() || expands_to.has_holes() || expands_to.is_ivar(),
// "unexpected arity 0 invention: upper bounds + priming with arity 0 inventions should have prevented this");
Expand All @@ -85,12 +111,18 @@ pub fn perform_expansion_variable(
// build our new pattern with all the variables we've just defined. Copy in the argchoices and prefixes
// from the old pattern.
// new_pattern.match_locations = locs;
pattern.body_utility = body_utility;

// println!("targeting ivar: {}", variable_ivar);
// println!("expands_to: {:?}", expands_to);
// println!("original: {:?}", new_pattern);
let variable_zids: Vec<usize> = remove_variable_at(&mut pattern, variable_ivar);
let variable_zids: Vec<usize> = remove_variable_at(&mut pattern, variable_ivar, &mut expands_to);

let body_utility = pattern.body_utility + compute_body_utility_change(shared, &expands_to) * variable_zids.len() as i32;
pattern.body_utility = body_utility;

// println!("pattern after remove: {:?}", pattern);
// println!("expands_to after remove: {:?}", expands_to);
// println!("variable_zids: {:?}", variable_zids);
// println!("after removing variable: {:?}", new_pattern);
// println!("variable_zids: {:?}", variable_zids);

Expand All @@ -115,10 +147,11 @@ pub fn perform_expansion_variable(
add_variable_at(&mut pattern, variable_zid, i);
}
}
// println!("pattern after add: {:?}", pattern);

// println!("after adding variable: {:?}", new_pattern);

// check_consistency(shared, &new_pattern);
// check_consistency(shared, &pattern);


Some (pattern)
Expand Down
93 changes: 88 additions & 5 deletions src/smc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,92 @@ use rand::SeedableRng;
use rustc_hash::{FxHashMap};
use std::sync::Arc;

fn sample_new_ivar(
original_pattern: &Pattern,
shared: &SharedData,
variable_ivar: usize,
match_loc: &usize,
rng: &mut impl rand::Rng,
) -> Option<usize> {
let num_vars = get_num_variables(original_pattern);
if num_vars <= 1 {
return None; // no other variable to expand to
}
let mut new_ivar = rng.gen_range(0..num_vars - 1);
if new_ivar >= variable_ivar {
new_ivar += 1; // skip the variable we are expanding
}
let zid_original = get_zid_for_ivar(original_pattern, variable_ivar);
let zid_new = get_zid_for_ivar(original_pattern, new_ivar);
if shared.arg_of_zid_node[zid_original][match_loc].shifted_id == shared.arg_of_zid_node[zid_new][match_loc].shifted_id {
return Some(new_ivar);
}
return None;
}

fn sample_variable_reuse(
pattern: &Pattern,
shared: &SharedData,
variable_ivar: usize,
match_location: usize,
rng: &mut impl rand::Rng,
) -> Option<(Pattern, ExpandsTo)> {
if let Some(new_ivar) = sample_new_ivar(pattern, shared, variable_ivar, &match_location, rng) {
let zid_original = get_zid_for_ivar(pattern, variable_ivar);
let zid_new = get_zid_for_ivar(pattern, new_ivar);
let locs = compatible_locations(
shared,
pattern,
&shared.arg_of_zid_node[zid_original],
&shared.arg_of_zid_node[zid_new],
);
if !locs.is_empty() {
let mut pattern = pattern.clone();
pattern.match_locations = locs;
let expands_to = ExpandsTo::IVar(new_ivar as i32);
return Some((pattern, expands_to));
}
}
None
}

fn sample_syntactic_expansion(
original_pattern: &Pattern,
arg_of_loc: &FxHashMap<Idx, Arg>,
match_location: usize,
) -> (Pattern, ExpandsTo) {
let mut pattern = original_pattern.clone();
let expands_to = arg_of_loc[&match_location].expands_to.clone();
pattern.match_locations.retain(
|loc| arg_of_loc[&loc].expands_to == expands_to
);
return (pattern, expands_to);
}

fn sample_expands_to(
original_pattern: &Pattern,
shared: &SharedData,
arg_of_loc: &FxHashMap<Idx,Arg>,
match_location: usize,
variable_ivar: usize,
rng: &mut impl rand::Rng,
) -> (Pattern, ExpandsTo) {
if let Some(out) = sample_variable_reuse(
original_pattern,
shared,
variable_ivar,
match_location,
rng,
) {
return out;
}
return sample_syntactic_expansion(
original_pattern,
arg_of_loc,
match_location,
);
}

pub fn smc_expand(
original_pattern: &Pattern,
shared: &SharedData,
Expand All @@ -26,14 +112,11 @@ pub fn smc_expand(
// hole_idx,
// );e
let variable_zid = get_zid_for_ivar(original_pattern, variable_ivar);
// println!("Original variable zid for ivar={}: {}", variable_ivar, variable_zid);
// println!("Variable ZID: {}", variable_zid);
let arg_of_loc = &shared.arg_of_zid_node[variable_zid];
// println!("Argument of location: {:?}", arg_of_loc);
let expands_to = arg_of_loc[&match_location].expands_to.clone();
let mut pattern = original_pattern.clone();
pattern.match_locations.retain(
|loc| arg_of_loc[&loc].expands_to == expands_to
);
let (pattern, expands_to) = sample_expands_to(original_pattern, shared, arg_of_loc, match_location, variable_ivar, rng);
perform_expansion_variable(
pattern,
&shared,
Expand Down