diff --git a/crates/graphrecords-core/src/graphrecord/datatypes/value.rs b/crates/graphrecords-core/src/graphrecord/datatypes/value.rs index 286f6ce..ec47444 100644 --- a/crates/graphrecords-core/src/graphrecord/datatypes/value.rs +++ b/crates/graphrecords-core/src/graphrecord/datatypes/value.rs @@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize}; use std::{ cmp::Ordering, fmt::Display, + hash::{Hash, Hasher}, ops::{Add, Div, Mul, Range, Sub}, }; @@ -52,15 +53,38 @@ where } } -// TODO: Add tests for Duration +fn canonicalize_float(value: f64) -> f64 { + if value.is_nan() { + f64::NAN + } else if value == 0.0 { + 0.0_f64 + } else { + value + } +} + +fn int_float_eq(int_value: i64, float_value: f64) -> bool { + let converted = int_value as f64; + + converted == float_value && converted as i64 == int_value +} + impl PartialEq for GraphRecordValue { fn eq(&self, other: &Self) -> bool { match (self, other) { (Self::String(value), Self::String(other)) => value == other, (Self::Int(value), Self::Int(other)) => value == other, - (Self::Int(value), Self::Float(other)) => &(*value as f64) == other, - (Self::Float(value), Self::Float(other)) => value == other, - (Self::Float(value), Self::Int(other)) => value == &(*other as f64), + (Self::Int(int_value), Self::Float(float_value)) + | (Self::Float(float_value), Self::Int(int_value)) => { + int_float_eq(*int_value, *float_value) + } + (Self::Float(value), Self::Float(other)) => { + if value.is_nan() { + other.is_nan() + } else { + value == other + } + } (Self::Bool(value), Self::Bool(other)) => value == other, (Self::DateTime(value), Self::DateTime(other)) => value == other, (Self::Duration(value), Self::Duration(other)) => value == other, @@ -70,6 +94,67 @@ impl PartialEq for GraphRecordValue { } } +impl Eq for GraphRecordValue {} + +impl GraphRecordValue { + const fn variant_rank(&self) -> u8 { + match self { + Self::Null => 0, + Self::Bool(_) => 1, + Self::Int(_) | Self::Float(_) => 2, + Self::String(_) => 3, + Self::DateTime(_) => 4, + Self::Duration(_) => 5, + } + } + + #[must_use] + pub fn total_cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (Self::String(value), Self::String(other)) => value.cmp(other), + (Self::Int(value), Self::Int(other)) => value.cmp(other), + (Self::Int(value), Self::Float(other)) => (*value as f64).total_cmp(other), + (Self::Float(value), Self::Int(other)) => value.total_cmp(&(*other as f64)), + (Self::Float(value), Self::Float(other)) => value.total_cmp(other), + (Self::Bool(value), Self::Bool(other)) => value.cmp(other), + (Self::DateTime(value), Self::DateTime(other)) => value.cmp(other), + (Self::Duration(value), Self::Duration(other)) => value.cmp(other), + (Self::Null, Self::Null) => Ordering::Equal, + _ => self.variant_rank().cmp(&other.variant_rank()), + } + } + + fn hash_discriminant(&self, state: &mut H) { + match self { + Self::Int(_) | Self::Float(_) => 0_u8.hash(state), + Self::String(_) => 1_u8.hash(state), + Self::Bool(_) => 2_u8.hash(state), + Self::DateTime(_) => 3_u8.hash(state), + Self::Duration(_) => 4_u8.hash(state), + Self::Null => 5_u8.hash(state), + } + } +} + +impl Hash for GraphRecordValue { + fn hash(&self, state: &mut H) { + self.hash_discriminant(state); + match self { + Self::Int(value) => { + canonicalize_float(*value as f64).to_bits().hash(state); + } + Self::Float(value) => { + canonicalize_float(*value).to_bits().hash(state); + } + Self::String(value) => value.hash(state), + Self::Bool(value) => value.hash(state), + Self::DateTime(value) => value.hash(state), + Self::Duration(value) => value.hash(state), + Self::Null => {} + } + } +} + // TODO: Add tests for Duration impl PartialOrd for GraphRecordValue { fn partial_cmp(&self, other: &Self) -> Option { @@ -1193,6 +1278,7 @@ mod test { }, }; use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; + use std::hash::{Hash, Hasher}; #[test] fn test_default() { @@ -1270,6 +1356,7 @@ mod test { assert!(GraphRecordValue::Int(0) == GraphRecordValue::Float(0_f64)); assert!(GraphRecordValue::Int(1) != GraphRecordValue::Float(0_f64)); + assert!(GraphRecordValue::Int(1) == GraphRecordValue::Float(1_f64)); assert!(GraphRecordValue::Float(0_f64) == GraphRecordValue::Float(0_f64)); assert!(GraphRecordValue::Float(1_f64) != GraphRecordValue::Float(0_f64)); @@ -1277,6 +1364,13 @@ mod test { assert!(GraphRecordValue::Float(0_f64) == GraphRecordValue::Int(0)); assert!(GraphRecordValue::Float(1_f64) != GraphRecordValue::Int(0)); + assert!(GraphRecordValue::Float(f64::NAN) == GraphRecordValue::Float(f64::NAN)); + assert!(GraphRecordValue::Float(-0.0) == GraphRecordValue::Float(0.0)); + + let large_int = (1_i64 << 53) + 1; + assert!(GraphRecordValue::Int(large_int) != GraphRecordValue::Float(large_int as f64)); + assert!(GraphRecordValue::Float(large_int as f64) != GraphRecordValue::Int(large_int)); + assert!(GraphRecordValue::Bool(false) == GraphRecordValue::Bool(false)); assert!(GraphRecordValue::Bool(true) != GraphRecordValue::Bool(false)); @@ -3189,4 +3283,59 @@ mod test { assert_eq!(GraphRecordValue::Null, GraphRecordValue::Null.uppercase()); } + + #[test] + fn test_hash() { + use std::collections::hash_map::DefaultHasher; + + let hash = |value: GraphRecordValue| -> u64 { + let mut hasher = DefaultHasher::new(); + value.hash(&mut hasher); + hasher.finish() + }; + + assert_eq!( + hash(GraphRecordValue::Int(1)), + hash(GraphRecordValue::Float(1.0)) + ); + assert_eq!( + hash(GraphRecordValue::Int(0)), + hash(GraphRecordValue::Float(0.0)) + ); + assert_eq!( + hash(GraphRecordValue::Float(-0.0)), + hash(GraphRecordValue::Float(0.0)) + ); + assert_eq!( + hash(GraphRecordValue::Float(f64::NAN)), + hash(GraphRecordValue::Float(f64::NAN)) + ); + assert_eq!(hash(GraphRecordValue::Null), hash(GraphRecordValue::Null)); + + assert_ne!( + hash(GraphRecordValue::Int(1)), + hash(GraphRecordValue::String("1".to_string())) + ); + assert_ne!( + hash(GraphRecordValue::Int(0)), + hash(GraphRecordValue::Bool(false)) + ); + } + + #[test] + fn test_eq_transitivity() { + let large_int = (1_i64 << 53) + 1; + let large_float = large_int as f64; + let rounded_int = large_float as i64; + + assert_ne!(large_int, rounded_int); + + let a = GraphRecordValue::Int(large_int); + let b = GraphRecordValue::Float(large_float); + let c = GraphRecordValue::Int(rounded_int); + + assert_ne!(a, b); + assert_eq!(b, c); + assert_ne!(a, c); + } } diff --git a/crates/graphrecords-core/src/graphrecord/overview/mod.rs b/crates/graphrecords-core/src/graphrecord/overview/mod.rs index 5609747..7e78dd4 100644 --- a/crates/graphrecords-core/src/graphrecord/overview/mod.rs +++ b/crates/graphrecords-core/src/graphrecord/overview/mod.rs @@ -14,7 +14,6 @@ use graphrecords_utils::aliases::GrHashMap; use itertools::Itertools; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use std::{ - cmp::Ordering, collections::HashSet, fmt::{Display, Formatter}, }; @@ -165,7 +164,7 @@ impl NodeGroupOverview { }) .evaluate()? .map(|(_, value)| value) - .sorted_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)) + .sorted_by(GraphRecordValue::total_cmp) .dedup_by(|a, b| a == b) .collect(); @@ -234,7 +233,7 @@ impl NodeGroupOverview { .evaluate() .unwrap() .map(|(_, value)| value) - .sorted_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)) + .sorted_by(GraphRecordValue::total_cmp) .dedup_by(|a, b| a == b) .count(); @@ -334,7 +333,7 @@ impl EdgeGroupOverview { }) .evaluate()? .map(|(_, value)| value) - .sorted_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)) + .sorted_by(GraphRecordValue::total_cmp) .dedup_by(|a, b| a == b) .collect(); @@ -401,7 +400,7 @@ impl EdgeGroupOverview { }) .evaluate()? .map(|(_, value)| value) - .sorted_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)) + .sorted_by(GraphRecordValue::total_cmp) .dedup_by(|a, b| a == b) .count();