Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 206 additions & 21 deletions implants/lib/eldritchv2/eldritch-core/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,32 +97,95 @@ impl fmt::Debug for Value {

impl PartialEq for Value {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
let mut visited = BTreeSet::new();
self.eq_helper(other, &mut visited)
}
}

impl Value {
fn eq_helper(&self, other: &Self, visited: &mut BTreeSet<(usize, usize)>) -> bool {
let p1 = match self {
Value::List(l) => Arc::as_ptr(l) as usize,
Value::Dictionary(d) => Arc::as_ptr(d) as usize,
Value::Set(s) => Arc::as_ptr(s) as usize,
_ => 0,
};
let p2 = match other {
Value::List(l) => Arc::as_ptr(l) as usize,
Value::Dictionary(d) => Arc::as_ptr(d) as usize,
Value::Set(s) => Arc::as_ptr(s) as usize,
_ => 0,
};

if p1 != 0 && p2 != 0 {
let pair = if p1 < p2 { (p1, p2) } else { (p2, p1) };
if visited.contains(&pair) {
return true;
}
visited.insert(pair);
}

let result = match (self, other) {
(Value::None, Value::None) => true,
(Value::Bool(a), Value::Bool(b)) => a == b,
(Value::Int(a), Value::Int(b)) => a == b,
(Value::Float(a), Value::Float(b)) => a == b, // Note: NaN != NaN usually, but handled by PartialOrd? No, PartialEq
(Value::Float(a), Value::Float(b)) => a == b,
(Value::String(a), Value::String(b)) => a == b,
(Value::Bytes(a), Value::Bytes(b)) => a == b,
(Value::List(a), Value::List(b)) => {
if Arc::ptr_eq(a, b) {
return true;
true
} else {
let la = a.read();
let lb = b.read();
if la.len() != lb.len() {
false
} else {
la.iter()
.zip(lb.iter())
.all(|(va, vb)| va.eq_helper(vb, visited))
}
}
a.read().eq(&*b.read())
}
(Value::Dictionary(a), Value::Dictionary(b)) => {
if Arc::ptr_eq(a, b) {
return true;
true
} else {
let da = a.read();
let db = b.read();
if da.len() != db.len() {
false
} else {
da.iter().zip(db.iter()).all(|((ka, va), (kb, vb))| {
ka.eq_helper(kb, visited) && va.eq_helper(vb, visited)
})
}
}
a.read().eq(&*b.read())
}
(Value::Set(a), Value::Set(b)) => {
if Arc::ptr_eq(a, b) {
return true;
true
} else {
let sa = a.read();
let sb = b.read();
if sa.len() != sb.len() {
false
} else {
sa.iter()
.zip(sb.iter())
.all(|(va, vb)| va.eq_helper(vb, visited))
}
}
}
(Value::Tuple(a), Value::Tuple(b)) => {
if a.len() != b.len() {
false
} else {
a.iter()
.zip(b.iter())
.all(|(va, vb)| va.eq_helper(vb, visited))
}
a.read().eq(&*b.read())
}
(Value::Tuple(a), Value::Tuple(b)) => a == b,
(Value::Function(a), Value::Function(b)) => a.name == b.name,
(Value::NativeFunction(a, _), Value::NativeFunction(b, _)) => a == b,
(Value::NativeFunctionWithKwargs(a, _), Value::NativeFunctionWithKwargs(b, _)) => {
Expand All @@ -131,7 +194,14 @@ impl PartialEq for Value {
(Value::BoundMethod(r1, n1), Value::BoundMethod(r2, n2)) => r1 == r2 && n1 == n2,
(Value::Foreign(a), Value::Foreign(b)) => Arc::ptr_eq(a, b),
_ => false,
};

if p1 != 0 && p2 != 0 {
let pair = if p1 < p2 { (p1, p2) } else { (p2, p1) };
visited.remove(&pair);
}

result
}
}

Expand All @@ -145,16 +215,48 @@ impl PartialOrd for Value {

impl Ord for Value {
fn cmp(&self, other: &Self) -> Ordering {
let mut visited = BTreeSet::new();
self.cmp_helper(other, &mut visited)
}
}

impl Value {
fn cmp_helper(&self, other: &Self, visited: &mut BTreeSet<(usize, usize)>) -> Ordering {
let p1 = match self {
Value::List(l) => Arc::as_ptr(l) as usize,
Value::Dictionary(d) => Arc::as_ptr(d) as usize,
Value::Set(s) => Arc::as_ptr(s) as usize,
_ => 0,
};
let p2 = match other {
Value::List(l) => Arc::as_ptr(l) as usize,
Value::Dictionary(d) => Arc::as_ptr(d) as usize,
Value::Set(s) => Arc::as_ptr(s) as usize,
_ => 0,
};

if p1 != 0 && p2 != 0 {
let pair = (p1, p2);
if visited.contains(&pair) {
return Ordering::Equal;
}
visited.insert(pair);
}

// Define an ordering between types:
// None < Bool < Int < Float < String < Bytes < List < Tuple < Dict < Set < Function < Native < Bound < Foreign
let self_discriminant = self.discriminant_value();
let other_discriminant = other.discriminant_value();

if self_discriminant != other_discriminant {
if p1 != 0 && p2 != 0 {
let pair = (p1, p2);
visited.remove(&pair);
}
return self_discriminant.cmp(&other_discriminant);
}

match (self, other) {
let result = match (self, other) {
(Value::None, Value::None) => Ordering::Equal,
(Value::Bool(a), Value::Bool(b)) => a.cmp(b),
(Value::Int(a), Value::Int(b)) => a.cmp(b),
Expand All @@ -163,27 +265,103 @@ impl Ord for Value {
(Value::Bytes(a), Value::Bytes(b)) => a.cmp(b),
(Value::List(a), Value::List(b)) => {
if Arc::ptr_eq(a, b) {
return Ordering::Equal;
Ordering::Equal
} else {
let la = a.read();
let lb = b.read();
// Lexicographical comparison with recursion
let len = la.len().min(lb.len());
let mut ord = Ordering::Equal;
for i in 0..len {
ord = la[i].cmp_helper(&lb[i], visited);
if ord != Ordering::Equal {
break;
}
}
if ord == Ordering::Equal {
la.len().cmp(&lb.len())
} else {
ord
}
}
}
(Value::Tuple(a), Value::Tuple(b)) => {
let len = a.len().min(b.len());
let mut ord = Ordering::Equal;
for i in 0..len {
ord = a[i].cmp_helper(&b[i], visited);
if ord != Ordering::Equal {
break;
}
}
if ord == Ordering::Equal {
a.len().cmp(&b.len())
} else {
ord
}
a.read().cmp(&*b.read())
}
(Value::Tuple(a), Value::Tuple(b)) => a.cmp(b),
(Value::Dictionary(a), Value::Dictionary(b)) => {
if Arc::ptr_eq(a, b) {
return Ordering::Equal;
Ordering::Equal
} else {
let da = a.read();
let db = b.read();
// Iterate and compare (key, value) pairs
let mut it1 = da.iter();
let mut it2 = db.iter();
loop {
match (it1.next(), it2.next()) {
(Some((k1, v1)), Some((k2, v2))) => {
let mut ord = k1.cmp_helper(k2, visited);
if ord == Ordering::Equal {
ord = v1.cmp_helper(v2, visited);
}
if ord != Ordering::Equal {
break ord;
}
}
(Some(_), None) => {
break Ordering::Greater;
}
(None, Some(_)) => {
break Ordering::Less;
}
(None, None) => {
break Ordering::Equal;
}
}
}
}
// BTreeMap implements Ord
a.read().cmp(&*b.read())
}
(Value::Set(a), Value::Set(b)) => {
if Arc::ptr_eq(a, b) {
return Ordering::Equal;
Ordering::Equal
} else {
let sa = a.read();
let sb = b.read();
let mut it1 = sa.iter();
let mut it2 = sb.iter();
loop {
match (it1.next(), it2.next()) {
(Some(v1), Some(v2)) => {
let ord = v1.cmp_helper(v2, visited);
if ord != Ordering::Equal {
break ord;
}
}
(Some(_), None) => {
break Ordering::Greater;
}
(None, Some(_)) => {
break Ordering::Less;
}
(None, None) => {
break Ordering::Equal;
}
}
}
}
// BTreeSet implements Ord
a.read().cmp(&*b.read())
}
// For functions and others, we just compare pointers or names as best effort
// This is primarily to satisfy BTreeSet requirement, not for user-facing logical ordering necessarily.
(Value::Function(a), Value::Function(b)) => a.name.cmp(&b.name),
(Value::NativeFunction(a, _), Value::NativeFunction(b, _)) => a.cmp(b),
(Value::NativeFunctionWithKwargs(a, _), Value::NativeFunctionWithKwargs(b, _)) => {
Expand All @@ -199,7 +377,14 @@ impl Ord for Value {
p1.cmp(&p2)
}
_ => Ordering::Equal, // Should be covered by discriminant check
};

if p1 != 0 && p2 != 0 {
let pair = (p1, p2);
visited.remove(&pair);
}

result
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use eldritch_core::{Interpreter, Printer, Span};
use std::sync::{Arc, Mutex};

// Mock printer to capture output
#[derive(Debug, Default)]
struct MockPrinter {
output: Arc<Mutex<String>>,
}

impl Printer for MockPrinter {
fn print_out(&self, _span: &Span, msg: &str) {
let mut out = self.output.lock().unwrap();
out.push_str(msg);
out.push('\n');
}
fn print_err(&self, _span: &Span, msg: &str) {
let mut out = self.output.lock().unwrap();
out.push_str("ERR: ");
out.push_str(msg);
out.push('\n');
}
}

#[test]
fn test_recursive_equality_deadlock() {
let mut interp = Interpreter::new();
let printer = Arc::new(MockPrinter::default());
interp.env.write().printer = printer.clone();

// a = []
// b = []
// a.append(b)
// b.append(a)
// print(a == b)

// Note: a == b returns True with our cycle detection (assume equal until proven otherwise).
// The main goal is to ensure this DOES NOT stack overflow.

let code = r#"
a = []
b = []
a.append(b)
b.append(a)
print("Comparing...")
x = (a == b)
print(x)
"#;

let result = interp.interpret(code);
if let Err(e) = result {
panic!("Interpreter failed: {:?}", e);
}

let output = printer.output.lock().unwrap();
assert!(output.contains("True"));
}

#[test]
fn test_recursive_list_print() {
let mut interp = Interpreter::new();
let printer = Arc::new(MockPrinter::default());
interp.env.write().printer = printer.clone();

let code = r#"
a = []
for i in range(1, 10):
a.append(a)
print(a)
"#;

let result = interp.interpret(code);
if let Err(e) = result {
panic!("Interpreter failed: {:?}", e);
}

let output = printer.output.lock().unwrap();
assert!(output.contains("[...]"));
}