diff --git a/implants/lib/eldritchv2/eldritch-core/src/ast.rs b/implants/lib/eldritchv2/eldritch-core/src/ast.rs index 33e4cb6be..67892cbcf 100644 --- a/implants/lib/eldritchv2/eldritch-core/src/ast.rs +++ b/implants/lib/eldritchv2/eldritch-core/src/ast.rs @@ -97,32 +97,95 @@ impl fmt::Debug for Value { impl PartialEq for Value { fn eq(&self, other: &Self) -> bool { - match (self, other) { + let mut visited = BTreeSet::new(); + self.eq_helper(other, &mut visited) + } +} + +impl Value { + fn eq_helper(&self, other: &Self, visited: &mut BTreeSet<(usize, usize)>) -> bool { + let p1 = match self { + Value::List(l) => Arc::as_ptr(l) as usize, + Value::Dictionary(d) => Arc::as_ptr(d) as usize, + Value::Set(s) => Arc::as_ptr(s) as usize, + _ => 0, + }; + let p2 = match other { + Value::List(l) => Arc::as_ptr(l) as usize, + Value::Dictionary(d) => Arc::as_ptr(d) as usize, + Value::Set(s) => Arc::as_ptr(s) as usize, + _ => 0, + }; + + if p1 != 0 && p2 != 0 { + let pair = if p1 < p2 { (p1, p2) } else { (p2, p1) }; + if visited.contains(&pair) { + return true; + } + visited.insert(pair); + } + + let result = match (self, other) { (Value::None, Value::None) => true, (Value::Bool(a), Value::Bool(b)) => a == b, (Value::Int(a), Value::Int(b)) => a == b, - (Value::Float(a), Value::Float(b)) => a == b, // Note: NaN != NaN usually, but handled by PartialOrd? No, PartialEq + (Value::Float(a), Value::Float(b)) => a == b, (Value::String(a), Value::String(b)) => a == b, (Value::Bytes(a), Value::Bytes(b)) => a == b, (Value::List(a), Value::List(b)) => { if Arc::ptr_eq(a, b) { - return true; + true + } else { + let la = a.read(); + let lb = b.read(); + if la.len() != lb.len() { + false + } else { + la.iter() + .zip(lb.iter()) + .all(|(va, vb)| va.eq_helper(vb, visited)) + } } - a.read().eq(&*b.read()) } (Value::Dictionary(a), Value::Dictionary(b)) => { if Arc::ptr_eq(a, b) { - return true; + true + } else { + let da = a.read(); + let db = b.read(); + if da.len() != db.len() { + false + } else { + da.iter().zip(db.iter()).all(|((ka, va), (kb, vb))| { + ka.eq_helper(kb, visited) && va.eq_helper(vb, visited) + }) + } } - a.read().eq(&*b.read()) } (Value::Set(a), Value::Set(b)) => { if Arc::ptr_eq(a, b) { - return true; + true + } else { + let sa = a.read(); + let sb = b.read(); + if sa.len() != sb.len() { + false + } else { + sa.iter() + .zip(sb.iter()) + .all(|(va, vb)| va.eq_helper(vb, visited)) + } + } + } + (Value::Tuple(a), Value::Tuple(b)) => { + if a.len() != b.len() { + false + } else { + a.iter() + .zip(b.iter()) + .all(|(va, vb)| va.eq_helper(vb, visited)) } - a.read().eq(&*b.read()) } - (Value::Tuple(a), Value::Tuple(b)) => a == b, (Value::Function(a), Value::Function(b)) => a.name == b.name, (Value::NativeFunction(a, _), Value::NativeFunction(b, _)) => a == b, (Value::NativeFunctionWithKwargs(a, _), Value::NativeFunctionWithKwargs(b, _)) => { @@ -131,7 +194,14 @@ impl PartialEq for Value { (Value::BoundMethod(r1, n1), Value::BoundMethod(r2, n2)) => r1 == r2 && n1 == n2, (Value::Foreign(a), Value::Foreign(b)) => Arc::ptr_eq(a, b), _ => false, + }; + + if p1 != 0 && p2 != 0 { + let pair = if p1 < p2 { (p1, p2) } else { (p2, p1) }; + visited.remove(&pair); } + + result } } @@ -145,16 +215,48 @@ impl PartialOrd for Value { impl Ord for Value { fn cmp(&self, other: &Self) -> Ordering { + let mut visited = BTreeSet::new(); + self.cmp_helper(other, &mut visited) + } +} + +impl Value { + fn cmp_helper(&self, other: &Self, visited: &mut BTreeSet<(usize, usize)>) -> Ordering { + let p1 = match self { + Value::List(l) => Arc::as_ptr(l) as usize, + Value::Dictionary(d) => Arc::as_ptr(d) as usize, + Value::Set(s) => Arc::as_ptr(s) as usize, + _ => 0, + }; + let p2 = match other { + Value::List(l) => Arc::as_ptr(l) as usize, + Value::Dictionary(d) => Arc::as_ptr(d) as usize, + Value::Set(s) => Arc::as_ptr(s) as usize, + _ => 0, + }; + + if p1 != 0 && p2 != 0 { + let pair = (p1, p2); + if visited.contains(&pair) { + return Ordering::Equal; + } + visited.insert(pair); + } + // Define an ordering between types: // None < Bool < Int < Float < String < Bytes < List < Tuple < Dict < Set < Function < Native < Bound < Foreign let self_discriminant = self.discriminant_value(); let other_discriminant = other.discriminant_value(); if self_discriminant != other_discriminant { + if p1 != 0 && p2 != 0 { + let pair = (p1, p2); + visited.remove(&pair); + } return self_discriminant.cmp(&other_discriminant); } - match (self, other) { + let result = match (self, other) { (Value::None, Value::None) => Ordering::Equal, (Value::Bool(a), Value::Bool(b)) => a.cmp(b), (Value::Int(a), Value::Int(b)) => a.cmp(b), @@ -163,27 +265,103 @@ impl Ord for Value { (Value::Bytes(a), Value::Bytes(b)) => a.cmp(b), (Value::List(a), Value::List(b)) => { if Arc::ptr_eq(a, b) { - return Ordering::Equal; + Ordering::Equal + } else { + let la = a.read(); + let lb = b.read(); + // Lexicographical comparison with recursion + let len = la.len().min(lb.len()); + let mut ord = Ordering::Equal; + for i in 0..len { + ord = la[i].cmp_helper(&lb[i], visited); + if ord != Ordering::Equal { + break; + } + } + if ord == Ordering::Equal { + la.len().cmp(&lb.len()) + } else { + ord + } + } + } + (Value::Tuple(a), Value::Tuple(b)) => { + let len = a.len().min(b.len()); + let mut ord = Ordering::Equal; + for i in 0..len { + ord = a[i].cmp_helper(&b[i], visited); + if ord != Ordering::Equal { + break; + } + } + if ord == Ordering::Equal { + a.len().cmp(&b.len()) + } else { + ord } - a.read().cmp(&*b.read()) } - (Value::Tuple(a), Value::Tuple(b)) => a.cmp(b), (Value::Dictionary(a), Value::Dictionary(b)) => { if Arc::ptr_eq(a, b) { - return Ordering::Equal; + Ordering::Equal + } else { + let da = a.read(); + let db = b.read(); + // Iterate and compare (key, value) pairs + let mut it1 = da.iter(); + let mut it2 = db.iter(); + loop { + match (it1.next(), it2.next()) { + (Some((k1, v1)), Some((k2, v2))) => { + let mut ord = k1.cmp_helper(k2, visited); + if ord == Ordering::Equal { + ord = v1.cmp_helper(v2, visited); + } + if ord != Ordering::Equal { + break ord; + } + } + (Some(_), None) => { + break Ordering::Greater; + } + (None, Some(_)) => { + break Ordering::Less; + } + (None, None) => { + break Ordering::Equal; + } + } + } } - // BTreeMap implements Ord - a.read().cmp(&*b.read()) } (Value::Set(a), Value::Set(b)) => { if Arc::ptr_eq(a, b) { - return Ordering::Equal; + Ordering::Equal + } else { + let sa = a.read(); + let sb = b.read(); + let mut it1 = sa.iter(); + let mut it2 = sb.iter(); + loop { + match (it1.next(), it2.next()) { + (Some(v1), Some(v2)) => { + let ord = v1.cmp_helper(v2, visited); + if ord != Ordering::Equal { + break ord; + } + } + (Some(_), None) => { + break Ordering::Greater; + } + (None, Some(_)) => { + break Ordering::Less; + } + (None, None) => { + break Ordering::Equal; + } + } + } } - // BTreeSet implements Ord - a.read().cmp(&*b.read()) } - // For functions and others, we just compare pointers or names as best effort - // This is primarily to satisfy BTreeSet requirement, not for user-facing logical ordering necessarily. (Value::Function(a), Value::Function(b)) => a.name.cmp(&b.name), (Value::NativeFunction(a, _), Value::NativeFunction(b, _)) => a.cmp(b), (Value::NativeFunctionWithKwargs(a, _), Value::NativeFunctionWithKwargs(b, _)) => { @@ -199,7 +377,14 @@ impl Ord for Value { p1.cmp(&p2) } _ => Ordering::Equal, // Should be covered by discriminant check + }; + + if p1 != 0 && p2 != 0 { + let pair = (p1, p2); + visited.remove(&pair); } + + result } } diff --git a/implants/lib/eldritchv2/eldritch-core/tests/regression_recursion.rs b/implants/lib/eldritchv2/eldritch-core/tests/regression_recursion.rs new file mode 100644 index 000000000..69b996739 --- /dev/null +++ b/implants/lib/eldritchv2/eldritch-core/tests/regression_recursion.rs @@ -0,0 +1,78 @@ +use eldritch_core::{Interpreter, Printer, Span}; +use std::sync::{Arc, Mutex}; + +// Mock printer to capture output +#[derive(Debug, Default)] +struct MockPrinter { + output: Arc>, +} + +impl Printer for MockPrinter { + fn print_out(&self, _span: &Span, msg: &str) { + let mut out = self.output.lock().unwrap(); + out.push_str(msg); + out.push('\n'); + } + fn print_err(&self, _span: &Span, msg: &str) { + let mut out = self.output.lock().unwrap(); + out.push_str("ERR: "); + out.push_str(msg); + out.push('\n'); + } +} + +#[test] +fn test_recursive_equality_deadlock() { + let mut interp = Interpreter::new(); + let printer = Arc::new(MockPrinter::default()); + interp.env.write().printer = printer.clone(); + + // a = [] + // b = [] + // a.append(b) + // b.append(a) + // print(a == b) + + // Note: a == b returns True with our cycle detection (assume equal until proven otherwise). + // The main goal is to ensure this DOES NOT stack overflow. + + let code = r#" +a = [] +b = [] +a.append(b) +b.append(a) +print("Comparing...") +x = (a == b) +print(x) +"#; + + let result = interp.interpret(code); + if let Err(e) = result { + panic!("Interpreter failed: {:?}", e); + } + + let output = printer.output.lock().unwrap(); + assert!(output.contains("True")); +} + +#[test] +fn test_recursive_list_print() { + let mut interp = Interpreter::new(); + let printer = Arc::new(MockPrinter::default()); + interp.env.write().printer = printer.clone(); + + let code = r#" +a = [] +for i in range(1, 10): + a.append(a) +print(a) +"#; + + let result = interp.interpret(code); + if let Err(e) = result { + panic!("Interpreter failed: {:?}", e); + } + + let output = printer.output.lock().unwrap(); + assert!(output.contains("[...]")); +}