From 3aefe84e67dc0148ff85dd444d5f9adbfc4304e2 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Tue, 27 Oct 2020 09:01:26 +0100 Subject: [PATCH 1/4] Moved JsonEqual to its own module. --- rust/arrow/src/array/array.rs | 2 +- rust/arrow/src/array/equal.rs | 920 --------------------------- rust/arrow/src/array/equal_json.rs | 985 +++++++++++++++++++++++++++++ rust/arrow/src/array/mod.rs | 3 +- 4 files changed, 988 insertions(+), 922 deletions(-) create mode 100644 rust/arrow/src/array/equal_json.rs diff --git a/rust/arrow/src/array/array.rs b/rust/arrow/src/array/array.rs index f3cc6b5fc60..e0863a09648 100644 --- a/rust/arrow/src/array/array.rs +++ b/rust/arrow/src/array/array.rs @@ -29,7 +29,7 @@ use num::Num; use super::*; use crate::array::builder::StringDictionaryBuilder; -use crate::array::equal::JsonEqual; +use crate::array::equal_json::JsonEqual; use crate::buffer::{buffer_bin_or, Buffer, MutableBuffer}; use crate::datatypes::DataType::Struct; use crate::datatypes::*; diff --git a/rust/arrow/src/array/equal.rs b/rust/arrow/src/array/equal.rs index 9a0b7e6b053..4c549286bff 100644 --- a/rust/arrow/src/array/equal.rs +++ b/rust/arrow/src/array/equal.rs @@ -22,9 +22,6 @@ use array::{ Array, BinaryOffsetSizeTrait, GenericBinaryArray, GenericListArray, GenericStringArray, ListArrayOps, OffsetSizeTrait, StringOffsetSizeTrait, }; -use hex::FromHex; -use serde_json::value::Value::{Null as JNull, Object, String as JString}; -use serde_json::Value; /// Trait for `Array` equality. pub trait ArrayEqual { @@ -829,345 +826,6 @@ fn value_offset_equal>( true } -/// Trait for comparing arrow array with json array -pub trait JsonEqual { - /// Checks whether arrow array equals to json array. - fn equals_json(&self, json: &[&Value]) -> bool; - - /// Checks whether arrow array equals to json array. - fn equals_json_values(&self, json: &[Value]) -> bool { - let refs = json.iter().collect::>(); - - self.equals_json(&refs) - } -} - -/// Implement array equals for numeric type -impl JsonEqual for PrimitiveArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - Value::Null => self.is_null(i), - v => self.is_valid(i) && Some(v) == self.value(i).into_json_value().as_ref(), - }) - } -} - -impl PartialEq for PrimitiveArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(array) => self.equals_json_values(&array), - _ => false, - } - } -} - -impl PartialEq> for Value { - fn eq(&self, arrow: &PrimitiveArray) -> bool { - match self { - Value::Array(array) => arrow.equals_json_values(&array), - _ => false, - } - } -} - -impl JsonEqual for GenericListArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - Value::Array(v) => self.is_valid(i) && self.value(i).equals_json_values(v), - Value::Null => self.is_null(i) || self.value_length(i).is_zero(), - _ => false, - }) - } -} - -impl PartialEq for GenericListArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq> for Value { - fn eq(&self, arrow: &GenericListArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for DictionaryArray { - fn equals_json(&self, json: &[&Value]) -> bool { - // todo: this is wrong: we must test the values also - self.keys().equals_json(json) - } -} - -impl PartialEq for DictionaryArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq> for Value { - fn eq(&self, arrow: &DictionaryArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for FixedSizeListArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - Value::Array(v) => self.is_valid(i) && self.value(i).equals_json_values(v), - Value::Null => self.is_null(i) || self.value_length() == 0, - _ => false, - }) - } -} - -impl PartialEq for FixedSizeListArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq for Value { - fn eq(&self, arrow: &FixedSizeListArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for StructArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - let all_object = json.iter().all(|v| match v { - Object(_) | JNull => true, - _ => false, - }); - - if !all_object { - return false; - } - - for column_name in self.column_names() { - let json_values = json - .iter() - .map(|obj| obj.get(column_name).unwrap_or(&Value::Null)) - .collect::>(); - - if !self - .column_by_name(column_name) - .map(|arr| arr.equals_json(&json_values)) - .unwrap_or(false) - { - return false; - } - } - - true - } -} - -impl PartialEq for StructArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(&json_array), - _ => false, - } - } -} - -impl PartialEq for Value { - fn eq(&self, arrow: &StructArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(&json_array), - _ => false, - } - } -} - -impl JsonEqual for GenericBinaryArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - JString(s) => { - // binary data is sometimes hex encoded, this checks if bytes are equal, - // and if not converting to hex is attempted - self.is_valid(i) - && (s.as_str().as_bytes() == self.value(i) - || Vec::from_hex(s.as_str()) == Ok(self.value(i).to_vec())) - } - JNull => self.is_null(i), - _ => false, - }) - } -} - -impl PartialEq - for GenericBinaryArray -{ - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(&json_array), - _ => false, - } - } -} - -impl PartialEq> - for Value -{ - fn eq(&self, arrow: &GenericBinaryArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(&json_array), - _ => false, - } - } -} - -impl JsonEqual for GenericStringArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - JString(s) => self.is_valid(i) && s.as_str() == self.value(i), - JNull => self.is_null(i), - _ => false, - }) - } -} - -impl PartialEq - for GenericStringArray -{ - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(&json_array), - _ => false, - } - } -} - -impl PartialEq> - for Value -{ - fn eq(&self, arrow: &GenericStringArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(&json_array), - _ => false, - } - } -} - -impl JsonEqual for FixedSizeBinaryArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - JString(s) => { - // binary data is sometimes hex encoded, this checks if bytes are equal, - // and if not converting to hex is attempted - self.is_valid(i) - && (s.as_str().as_bytes() == self.value(i) - || Vec::from_hex(s.as_str()) == Ok(self.value(i).to_vec())) - } - JNull => self.is_null(i), - _ => false, - }) - } -} - -impl PartialEq for FixedSizeBinaryArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(&json_array), - _ => false, - } - } -} - -impl PartialEq for Value { - fn eq(&self, arrow: &FixedSizeBinaryArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(&json_array), - _ => false, - } - } -} - -impl JsonEqual for UnionArray { - fn equals_json(&self, _json: &[&Value]) -> bool { - unimplemented!( - "Added to allow UnionArray to implement the Array trait: see ARROW-8547" - ) - } -} - -impl JsonEqual for NullArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - // all JSON values must be nulls - json.iter().all(|&v| v == &JNull) - } -} - -impl PartialEq for Value { - fn eq(&self, arrow: &NullArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(&json_array), - _ => false, - } - } -} - -impl PartialEq for NullArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(&json_array), - _ => false, - } - } -} - #[cfg(test)] mod tests { use super::*; @@ -1576,582 +1234,4 @@ mod tests { } Ok(builder.finish()) } - - #[test] - fn test_primitive_json_equal() { - // Test equaled array - let arrow_array = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); - let json_array: Value = serde_json::from_str( - r#" - [ - 1, null, 2, 3 - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequaled array - let arrow_array = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); - let json_array: Value = serde_json::from_str( - r#" - [ - 1, 1, 2, 3 - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test unequal length case - let arrow_array = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); - let json_array: Value = serde_json::from_str( - r#" - [ - 1, 1 - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test not json array type case - let arrow_array = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_list_json_equal() { - // Test equal case - let arrow_array = create_list_array( - &mut ListBuilder::new(Int32Builder::new(10)), - &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], - ) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - [ - [1, 2, 3], - null, - [4, 5, 6] - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal case - let arrow_array = create_list_array( - &mut ListBuilder::new(Int32Builder::new(10)), - &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], - ) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - [ - [1, 2, 3], - [7, 8], - [4, 5, 6] - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let arrow_array = create_list_array( - &mut ListBuilder::new(Int32Builder::new(10)), - &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], - ) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_fixed_size_list_json_equal() { - // Test equal case - let arrow_array = create_fixed_size_list_array( - &mut FixedSizeListBuilder::new(Int32Builder::new(10), 3), - &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], - ) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - [ - [1, 2, 3], - null, - [4, 5, 6] - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal case - let arrow_array = create_fixed_size_list_array( - &mut FixedSizeListBuilder::new(Int32Builder::new(10), 3), - &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], - ) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - [ - [1, 2, 3], - [7, 8, 9], - [4, 5, 6] - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let arrow_array = create_fixed_size_list_array( - &mut FixedSizeListBuilder::new(Int32Builder::new(10), 3), - &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], - ) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_string_json_equal() { - // Test the equal case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None, None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "world", - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None, None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "arrow", - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test unequal length case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "arrow", - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect value type case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - 1, - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_binary_json_equal() { - // Test the equal case - let mut builder = BinaryBuilder::new(6); - builder.append_value(b"hello").unwrap(); - builder.append_null().unwrap(); - builder.append_null().unwrap(); - builder.append_value(b"world").unwrap(); - builder.append_null().unwrap(); - builder.append_null().unwrap(); - let arrow_array = builder.finish(); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "world", - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None, None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "arrow", - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test unequal length case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "arrow", - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect value type case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - 1, - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_fixed_size_binary_json_equal() { - // Test the equal case - let mut builder = FixedSizeBinaryBuilder::new(15, 5); - builder.append_value(b"hello").unwrap(); - builder.append_null().unwrap(); - builder.append_value(b"world").unwrap(); - let arrow_array: FixedSizeBinaryArray = builder.finish(); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - "world" - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal case - builder.append_value(b"hello").unwrap(); - builder.append_null().unwrap(); - builder.append_value(b"world").unwrap(); - let arrow_array: FixedSizeBinaryArray = builder.finish(); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - "arrow" - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test unequal length case - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "world" - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect value type case - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - 1 - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_struct_json_equal() { - let strings: ArrayRef = Arc::new(StringArray::from(vec![ - Some("joe"), - None, - None, - Some("mark"), - Some("doe"), - ])); - let ints: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - Some(2), - None, - Some(4), - Some(5), - ])); - - let arrow_array = - StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]) - .unwrap(); - - let json_array: Value = serde_json::from_str( - r#" - [ - { - "f1": "joe", - "f2": 1 - }, - { - "f2": 2 - }, - null, - { - "f1": "mark", - "f2": 4 - }, - { - "f1": "doe", - "f2": 5 - } - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal length case - let json_array: Value = serde_json::from_str( - r#" - [ - { - "f1": "joe", - "f2": 1 - }, - { - "f2": 2 - }, - null, - { - "f1": "mark", - "f2": 4 - } - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let json_array: Value = serde_json::from_str( - r#" - { - "f1": "joe", - "f2": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test not all object case - let json_array: Value = serde_json::from_str( - r#" - [ - { - "f1": "joe", - "f2": 1 - }, - 2, - null, - { - "f1": "mark", - "f2": 4 - } - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_null_json_equal() { - // Test equaled array - let arrow_array = NullArray::new(4); - let json_array: Value = serde_json::from_str( - r#" - [ - null, null, null, null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequaled array - let arrow_array = NullArray::new(2); - let json_array: Value = serde_json::from_str( - r#" - [ - null, null, null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } } diff --git a/rust/arrow/src/array/equal_json.rs b/rust/arrow/src/array/equal_json.rs new file mode 100644 index 00000000000..d29b84529b6 --- /dev/null +++ b/rust/arrow/src/array/equal_json.rs @@ -0,0 +1,985 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; +use crate::datatypes::*; +use array::{ + Array, BinaryOffsetSizeTrait, GenericBinaryArray, GenericListArray, + GenericStringArray, OffsetSizeTrait, StringOffsetSizeTrait, +}; +use hex::FromHex; +use serde_json::value::Value::{Null as JNull, Object, String as JString}; +use serde_json::Value; + +/// Trait for comparing arrow array with json array +pub trait JsonEqual { + /// Checks whether arrow array equals to json array. + fn equals_json(&self, json: &[&Value]) -> bool; + + /// Checks whether arrow array equals to json array. + fn equals_json_values(&self, json: &[Value]) -> bool { + let refs = json.iter().collect::>(); + + self.equals_json(&refs) + } +} + +/// Implement array equals for numeric type +impl JsonEqual for PrimitiveArray { + fn equals_json(&self, json: &[&Value]) -> bool { + if self.len() != json.len() { + return false; + } + + (0..self.len()).all(|i| match json[i] { + Value::Null => self.is_null(i), + v => self.is_valid(i) && Some(v) == self.value(i).into_json_value().as_ref(), + }) + } +} + +impl PartialEq for PrimitiveArray { + fn eq(&self, json: &Value) -> bool { + match json { + Value::Array(array) => self.equals_json_values(&array), + _ => false, + } + } +} + +impl PartialEq> for Value { + fn eq(&self, arrow: &PrimitiveArray) -> bool { + match self { + Value::Array(array) => arrow.equals_json_values(&array), + _ => false, + } + } +} + +impl JsonEqual for GenericListArray { + fn equals_json(&self, json: &[&Value]) -> bool { + if self.len() != json.len() { + return false; + } + + (0..self.len()).all(|i| match json[i] { + Value::Array(v) => self.is_valid(i) && self.value(i).equals_json_values(v), + Value::Null => self.is_null(i) || self.value_length(i).is_zero(), + _ => false, + }) + } +} + +impl PartialEq for GenericListArray { + fn eq(&self, json: &Value) -> bool { + match json { + Value::Array(json_array) => self.equals_json_values(json_array), + _ => false, + } + } +} + +impl PartialEq> for Value { + fn eq(&self, arrow: &GenericListArray) -> bool { + match self { + Value::Array(json_array) => arrow.equals_json_values(json_array), + _ => false, + } + } +} + +impl JsonEqual for DictionaryArray { + fn equals_json(&self, json: &[&Value]) -> bool { + // todo: this is wrong: we must test the values also + self.keys().equals_json(json) + } +} + +impl PartialEq for DictionaryArray { + fn eq(&self, json: &Value) -> bool { + match json { + Value::Array(json_array) => self.equals_json_values(json_array), + _ => false, + } + } +} + +impl PartialEq> for Value { + fn eq(&self, arrow: &DictionaryArray) -> bool { + match self { + Value::Array(json_array) => arrow.equals_json_values(json_array), + _ => false, + } + } +} + +impl JsonEqual for FixedSizeListArray { + fn equals_json(&self, json: &[&Value]) -> bool { + if self.len() != json.len() { + return false; + } + + (0..self.len()).all(|i| match json[i] { + Value::Array(v) => self.is_valid(i) && self.value(i).equals_json_values(v), + Value::Null => self.is_null(i) || self.value_length() == 0, + _ => false, + }) + } +} + +impl PartialEq for FixedSizeListArray { + fn eq(&self, json: &Value) -> bool { + match json { + Value::Array(json_array) => self.equals_json_values(json_array), + _ => false, + } + } +} + +impl PartialEq for Value { + fn eq(&self, arrow: &FixedSizeListArray) -> bool { + match self { + Value::Array(json_array) => arrow.equals_json_values(json_array), + _ => false, + } + } +} + +impl JsonEqual for StructArray { + fn equals_json(&self, json: &[&Value]) -> bool { + if self.len() != json.len() { + return false; + } + + let all_object = json.iter().all(|v| match v { + Object(_) | JNull => true, + _ => false, + }); + + if !all_object { + return false; + } + + for column_name in self.column_names() { + let json_values = json + .iter() + .map(|obj| obj.get(column_name).unwrap_or(&Value::Null)) + .collect::>(); + + if !self + .column_by_name(column_name) + .map(|arr| arr.equals_json(&json_values)) + .unwrap_or(false) + { + return false; + } + } + + true + } +} + +impl PartialEq for StructArray { + fn eq(&self, json: &Value) -> bool { + match json { + Value::Array(json_array) => self.equals_json_values(&json_array), + _ => false, + } + } +} + +impl PartialEq for Value { + fn eq(&self, arrow: &StructArray) -> bool { + match self { + Value::Array(json_array) => arrow.equals_json_values(&json_array), + _ => false, + } + } +} + +impl JsonEqual for GenericBinaryArray { + fn equals_json(&self, json: &[&Value]) -> bool { + if self.len() != json.len() { + return false; + } + + (0..self.len()).all(|i| match json[i] { + JString(s) => { + // binary data is sometimes hex encoded, this checks if bytes are equal, + // and if not converting to hex is attempted + self.is_valid(i) + && (s.as_str().as_bytes() == self.value(i) + || Vec::from_hex(s.as_str()) == Ok(self.value(i).to_vec())) + } + JNull => self.is_null(i), + _ => false, + }) + } +} + +impl PartialEq + for GenericBinaryArray +{ + fn eq(&self, json: &Value) -> bool { + match json { + Value::Array(json_array) => self.equals_json_values(&json_array), + _ => false, + } + } +} + +impl PartialEq> + for Value +{ + fn eq(&self, arrow: &GenericBinaryArray) -> bool { + match self { + Value::Array(json_array) => arrow.equals_json_values(&json_array), + _ => false, + } + } +} + +impl JsonEqual for GenericStringArray { + fn equals_json(&self, json: &[&Value]) -> bool { + if self.len() != json.len() { + return false; + } + + (0..self.len()).all(|i| match json[i] { + JString(s) => self.is_valid(i) && s.as_str() == self.value(i), + JNull => self.is_null(i), + _ => false, + }) + } +} + +impl PartialEq + for GenericStringArray +{ + fn eq(&self, json: &Value) -> bool { + match json { + Value::Array(json_array) => self.equals_json_values(&json_array), + _ => false, + } + } +} + +impl PartialEq> + for Value +{ + fn eq(&self, arrow: &GenericStringArray) -> bool { + match self { + Value::Array(json_array) => arrow.equals_json_values(&json_array), + _ => false, + } + } +} + +impl JsonEqual for FixedSizeBinaryArray { + fn equals_json(&self, json: &[&Value]) -> bool { + if self.len() != json.len() { + return false; + } + + (0..self.len()).all(|i| match json[i] { + JString(s) => { + // binary data is sometimes hex encoded, this checks if bytes are equal, + // and if not converting to hex is attempted + self.is_valid(i) + && (s.as_str().as_bytes() == self.value(i) + || Vec::from_hex(s.as_str()) == Ok(self.value(i).to_vec())) + } + JNull => self.is_null(i), + _ => false, + }) + } +} + +impl PartialEq for FixedSizeBinaryArray { + fn eq(&self, json: &Value) -> bool { + match json { + Value::Array(json_array) => self.equals_json_values(&json_array), + _ => false, + } + } +} + +impl PartialEq for Value { + fn eq(&self, arrow: &FixedSizeBinaryArray) -> bool { + match self { + Value::Array(json_array) => arrow.equals_json_values(&json_array), + _ => false, + } + } +} + +impl JsonEqual for UnionArray { + fn equals_json(&self, _json: &[&Value]) -> bool { + unimplemented!( + "Added to allow UnionArray to implement the Array trait: see ARROW-8547" + ) + } +} + +impl JsonEqual for NullArray { + fn equals_json(&self, json: &[&Value]) -> bool { + if self.len() != json.len() { + return false; + } + + // all JSON values must be nulls + json.iter().all(|&v| v == &JNull) + } +} + +impl PartialEq for Value { + fn eq(&self, arrow: &NullArray) -> bool { + match self { + Value::Array(json_array) => arrow.equals_json_values(&json_array), + _ => false, + } + } +} + +impl PartialEq for NullArray { + fn eq(&self, json: &Value) -> bool { + match json { + Value::Array(json_array) => self.equals_json_values(&json_array), + _ => false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::error::Result; + use std::{convert::TryFrom, sync::Arc}; + + fn create_list_array<'a, U: AsRef<[i32]>, T: AsRef<[Option]>>( + builder: &'a mut ListBuilder, + data: T, + ) -> Result { + for d in data.as_ref() { + if let Some(v) = d { + builder.values().append_slice(v.as_ref())?; + builder.append(true)? + } else { + builder.append(false)? + } + } + Ok(builder.finish()) + } + + /// Create a fixed size list of 2 value lengths + fn create_fixed_size_list_array<'a, U: AsRef<[i32]>, T: AsRef<[Option]>>( + builder: &'a mut FixedSizeListBuilder, + data: T, + ) -> Result { + for d in data.as_ref() { + if let Some(v) = d { + builder.values().append_slice(v.as_ref())?; + builder.append(true)? + } else { + for _ in 0..builder.value_length() { + builder.values().append_null()?; + } + builder.append(false)? + } + } + Ok(builder.finish()) + } + + #[test] + fn test_primitive_json_equal() { + // Test equaled array + let arrow_array = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); + let json_array: Value = serde_json::from_str( + r#" + [ + 1, null, 2, 3 + ] + "#, + ) + .unwrap(); + assert!(arrow_array.eq(&json_array)); + assert!(json_array.eq(&arrow_array)); + + // Test unequaled array + let arrow_array = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); + let json_array: Value = serde_json::from_str( + r#" + [ + 1, 1, 2, 3 + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test unequal length case + let arrow_array = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); + let json_array: Value = serde_json::from_str( + r#" + [ + 1, 1 + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test not json array type case + let arrow_array = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); + let json_array: Value = serde_json::from_str( + r#" + { + "a": 1 + } + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + } + + #[test] + fn test_list_json_equal() { + // Test equal case + let arrow_array = create_list_array( + &mut ListBuilder::new(Int32Builder::new(10)), + &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], + ) + .unwrap(); + let json_array: Value = serde_json::from_str( + r#" + [ + [1, 2, 3], + null, + [4, 5, 6] + ] + "#, + ) + .unwrap(); + assert!(arrow_array.eq(&json_array)); + assert!(json_array.eq(&arrow_array)); + + // Test unequal case + let arrow_array = create_list_array( + &mut ListBuilder::new(Int32Builder::new(10)), + &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], + ) + .unwrap(); + let json_array: Value = serde_json::from_str( + r#" + [ + [1, 2, 3], + [7, 8], + [4, 5, 6] + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test incorrect type case + let arrow_array = create_list_array( + &mut ListBuilder::new(Int32Builder::new(10)), + &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], + ) + .unwrap(); + let json_array: Value = serde_json::from_str( + r#" + { + "a": 1 + } + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + } + + #[test] + fn test_fixed_size_list_json_equal() { + // Test equal case + let arrow_array = create_fixed_size_list_array( + &mut FixedSizeListBuilder::new(Int32Builder::new(10), 3), + &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], + ) + .unwrap(); + let json_array: Value = serde_json::from_str( + r#" + [ + [1, 2, 3], + null, + [4, 5, 6] + ] + "#, + ) + .unwrap(); + assert!(arrow_array.eq(&json_array)); + assert!(json_array.eq(&arrow_array)); + + // Test unequal case + let arrow_array = create_fixed_size_list_array( + &mut FixedSizeListBuilder::new(Int32Builder::new(10), 3), + &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], + ) + .unwrap(); + let json_array: Value = serde_json::from_str( + r#" + [ + [1, 2, 3], + [7, 8, 9], + [4, 5, 6] + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test incorrect type case + let arrow_array = create_fixed_size_list_array( + &mut FixedSizeListBuilder::new(Int32Builder::new(10), 3), + &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], + ) + .unwrap(); + let json_array: Value = serde_json::from_str( + r#" + { + "a": 1 + } + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + } + + #[test] + fn test_string_json_equal() { + // Test the equal case + let arrow_array = + StringArray::from(vec![Some("hello"), None, None, Some("world"), None, None]); + let json_array: Value = serde_json::from_str( + r#" + [ + "hello", + null, + null, + "world", + null, + null + ] + "#, + ) + .unwrap(); + assert!(arrow_array.eq(&json_array)); + assert!(json_array.eq(&arrow_array)); + + // Test unequal case + let arrow_array = + StringArray::from(vec![Some("hello"), None, None, Some("world"), None, None]); + let json_array: Value = serde_json::from_str( + r#" + [ + "hello", + null, + null, + "arrow", + null, + null + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test unequal length case + let arrow_array = + StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); + let json_array: Value = serde_json::from_str( + r#" + [ + "hello", + null, + null, + "arrow", + null, + null + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test incorrect type case + let arrow_array = + StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); + let json_array: Value = serde_json::from_str( + r#" + { + "a": 1 + } + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test incorrect value type case + let arrow_array = + StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); + let json_array: Value = serde_json::from_str( + r#" + [ + "hello", + null, + null, + 1, + null, + null + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + } + + #[test] + fn test_binary_json_equal() { + // Test the equal case + let mut builder = BinaryBuilder::new(6); + builder.append_value(b"hello").unwrap(); + builder.append_null().unwrap(); + builder.append_null().unwrap(); + builder.append_value(b"world").unwrap(); + builder.append_null().unwrap(); + builder.append_null().unwrap(); + let arrow_array = builder.finish(); + let json_array: Value = serde_json::from_str( + r#" + [ + "hello", + null, + null, + "world", + null, + null + ] + "#, + ) + .unwrap(); + assert!(arrow_array.eq(&json_array)); + assert!(json_array.eq(&arrow_array)); + + // Test unequal case + let arrow_array = + StringArray::from(vec![Some("hello"), None, None, Some("world"), None, None]); + let json_array: Value = serde_json::from_str( + r#" + [ + "hello", + null, + null, + "arrow", + null, + null + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test unequal length case + let arrow_array = + StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); + let json_array: Value = serde_json::from_str( + r#" + [ + "hello", + null, + null, + "arrow", + null, + null + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test incorrect type case + let arrow_array = + StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); + let json_array: Value = serde_json::from_str( + r#" + { + "a": 1 + } + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test incorrect value type case + let arrow_array = + StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); + let json_array: Value = serde_json::from_str( + r#" + [ + "hello", + null, + null, + 1, + null, + null + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + } + + #[test] + fn test_fixed_size_binary_json_equal() { + // Test the equal case + let mut builder = FixedSizeBinaryBuilder::new(15, 5); + builder.append_value(b"hello").unwrap(); + builder.append_null().unwrap(); + builder.append_value(b"world").unwrap(); + let arrow_array: FixedSizeBinaryArray = builder.finish(); + let json_array: Value = serde_json::from_str( + r#" + [ + "hello", + null, + "world" + ] + "#, + ) + .unwrap(); + assert!(arrow_array.eq(&json_array)); + assert!(json_array.eq(&arrow_array)); + + // Test unequal case + builder.append_value(b"hello").unwrap(); + builder.append_null().unwrap(); + builder.append_value(b"world").unwrap(); + let arrow_array: FixedSizeBinaryArray = builder.finish(); + let json_array: Value = serde_json::from_str( + r#" + [ + "hello", + null, + "arrow" + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test unequal length case + let json_array: Value = serde_json::from_str( + r#" + [ + "hello", + null, + null, + "world" + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test incorrect type case + let json_array: Value = serde_json::from_str( + r#" + { + "a": 1 + } + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test incorrect value type case + let json_array: Value = serde_json::from_str( + r#" + [ + "hello", + null, + 1 + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + } + + #[test] + fn test_struct_json_equal() { + let strings: ArrayRef = Arc::new(StringArray::from(vec![ + Some("joe"), + None, + None, + Some("mark"), + Some("doe"), + ])); + let ints: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(4), + Some(5), + ])); + + let arrow_array = + StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]) + .unwrap(); + + let json_array: Value = serde_json::from_str( + r#" + [ + { + "f1": "joe", + "f2": 1 + }, + { + "f2": 2 + }, + null, + { + "f1": "mark", + "f2": 4 + }, + { + "f1": "doe", + "f2": 5 + } + ] + "#, + ) + .unwrap(); + assert!(arrow_array.eq(&json_array)); + assert!(json_array.eq(&arrow_array)); + + // Test unequal length case + let json_array: Value = serde_json::from_str( + r#" + [ + { + "f1": "joe", + "f2": 1 + }, + { + "f2": 2 + }, + null, + { + "f1": "mark", + "f2": 4 + } + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test incorrect type case + let json_array: Value = serde_json::from_str( + r#" + { + "f1": "joe", + "f2": 1 + } + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + + // Test not all object case + let json_array: Value = serde_json::from_str( + r#" + [ + { + "f1": "joe", + "f2": 1 + }, + 2, + null, + { + "f1": "mark", + "f2": 4 + } + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + } + + #[test] + fn test_null_json_equal() { + // Test equaled array + let arrow_array = NullArray::new(4); + let json_array: Value = serde_json::from_str( + r#" + [ + null, null, null, null + ] + "#, + ) + .unwrap(); + assert!(arrow_array.eq(&json_array)); + assert!(json_array.eq(&arrow_array)); + + // Test unequaled array + let arrow_array = NullArray::new(2); + let json_array: Value = serde_json::from_str( + r#" + [ + null, null, null + ] + "#, + ) + .unwrap(); + assert!(arrow_array.ne(&json_array)); + assert!(json_array.ne(&arrow_array)); + } +} diff --git a/rust/arrow/src/array/mod.rs b/rust/arrow/src/array/mod.rs index 0a996e93257..88d5d74cac4 100644 --- a/rust/arrow/src/array/mod.rs +++ b/rust/arrow/src/array/mod.rs @@ -87,6 +87,7 @@ mod builder; mod cast; mod data; mod equal; +mod equal_json; mod iterator; mod null; mod ord; @@ -247,7 +248,7 @@ pub use self::iterator::*; // --------------------- Array Equality --------------------- pub use self::equal::ArrayEqual; -pub use self::equal::JsonEqual; +pub use self::equal_json::JsonEqual; // --------------------- Array's values comparison --------------------- From 8cde0957dbcae38dd4baaf1a6896ec59ef6bf682 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Wed, 28 Oct 2020 22:25:42 +0100 Subject: [PATCH 2/4] Added bench for equality. --- rust/arrow/Cargo.toml | 4 ++ rust/arrow/benches/equal.rs | 93 +++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 rust/arrow/benches/equal.rs diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml index c7558e76a40..71445768207 100644 --- a/rust/arrow/Cargo.toml +++ b/rust/arrow/Cargo.toml @@ -112,3 +112,7 @@ harness = false [[bench]] name = "csv_writer" harness = false + +[[bench]] +name = "equal" +harness = false diff --git a/rust/arrow/benches/equal.rs b/rust/arrow/benches/equal.rs new file mode 100644 index 00000000000..b70b633ff6d --- /dev/null +++ b/rust/arrow/benches/equal.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +use criterion::Criterion; + +use rand::distributions::Alphanumeric; +use rand::Rng; +use std::sync::Arc; + +extern crate arrow; + +use arrow::array::*; + +fn create_string_array(size: usize, with_nulls: bool) -> ArrayRef { + // use random numbers to avoid spurious compiler optimizations wrt to branching + let mut rng = rand::thread_rng(); + let mut builder = StringBuilder::new(size); + + for _ in 0..size { + if with_nulls && rng.gen::() > 0.5 { + builder.append_null().unwrap(); + } else { + let string = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(10) + .collect::(); + builder.append_value(&string).unwrap(); + } + } + Arc::new(builder.finish()) +} + +fn create_array(size: usize, with_nulls: bool) -> ArrayRef { + // use random numbers to avoid spurious compiler optimizations wrt to branching + let mut rng = rand::thread_rng(); + let mut builder = Float32Builder::new(size); + + for _ in 0..size { + if with_nulls && rng.gen::() > 0.5 { + builder.append_null().unwrap(); + } else { + builder.append_value(rng.gen()).unwrap(); + } + } + Arc::new(builder.finish()) +} + +fn bench_equal(arr_a: &ArrayRef) { + let arr_a = arr_a.as_any().downcast_ref::().unwrap(); + criterion::black_box(arr_a.equals(arr_a)); +} + +fn bench_equal_string(arr_a: &ArrayRef) { + let arr_a = arr_a.as_any().downcast_ref::().unwrap(); + criterion::black_box(arr_a.equals(arr_a)); +} + +fn add_benchmark(c: &mut Criterion) { + let arr_a = create_array(512, false); + c.bench_function("equal_512", |b| b.iter(|| bench_equal(&arr_a))); + + let arr_a_nulls = create_array(512, true); + c.bench_function("equal_nulls_512", |b| b.iter(|| bench_equal(&arr_a_nulls))); + + let arr_a = create_string_array(512, false); + c.bench_function("equal_string_512", |b| { + b.iter(|| bench_equal_string(&arr_a)) + }); + + let arr_a_nulls = create_string_array(512, true); + c.bench_function("equal_string_nulls_512", |b| { + b.iter(|| bench_equal_string(&arr_a_nulls)) + }); +} + +criterion_group!(benches, add_benchmark); +criterion_main!(benches); From eec30c7804d7bc6564be736aa6199930cfdff654 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 25 Oct 2020 07:22:39 +0100 Subject: [PATCH 3/4] Moved function to inside ArrayData and unit-tested it. --- rust/arrow/src/array/array.rs | 33 ++-------------- rust/arrow/src/array/data.rs | 71 +++++++++++++++++++++++++++-------- rust/arrow/src/array/equal.rs | 6 +-- 3 files changed, 62 insertions(+), 48 deletions(-) diff --git a/rust/arrow/src/array/array.rs b/rust/arrow/src/array/array.rs index e0863a09648..3a202c0b5c4 100644 --- a/rust/arrow/src/array/array.rs +++ b/rust/arrow/src/array/array.rs @@ -115,7 +115,7 @@ pub trait Array: fmt::Debug + Send + Sync + ArrayEqual + JsonEqual { /// assert!(array_slice.equals(&Int32Array::from(vec![2, 3, 4]))); /// ``` fn slice(&self, offset: usize, length: usize) -> ArrayRef { - make_array(slice_data(self.data_ref(), offset, length)) + make_array(Arc::new(self.data_ref().as_ref().slice(offset, length))) } /// Returns the length (i.e., number of elements) of this array. @@ -338,33 +338,6 @@ pub fn make_array(data: ArrayDataRef) -> ArrayRef { } } -/// Creates a zero-copy slice of the array's data. -/// -/// # Panics -/// -/// Panics if `offset + length > data.len()`. -fn slice_data(data: &ArrayDataRef, mut offset: usize, length: usize) -> ArrayDataRef { - assert!((offset + length) <= data.len()); - - let mut new_data = data.as_ref().clone(); - let len = std::cmp::min(new_data.len - offset, length); - - offset += data.offset; - new_data.len = len; - new_data.offset = offset; - - // Calculate the new null count based on the offset - new_data.null_count = if let Some(bitmap) = new_data.null_bitmap() { - let valid_bits = bitmap.bits.data(); - len.checked_sub(bit_util::count_set_bits_offset(valid_bits, offset, length)) - .unwrap() - } else { - 0 - }; - - Arc::new(new_data) -} - // creates a new MutableBuffer initializes all falsed // this is useful to populate null bitmaps fn make_null_buffer(len: usize) -> MutableBuffer { @@ -1935,8 +1908,8 @@ impl From for StructArray { fn from(data: ArrayDataRef) -> Self { let mut boxed_fields = vec![]; for cd in data.child_data() { - let child_data = if data.offset != 0 || data.len != cd.len { - slice_data(&cd, data.offset, data.len) + let child_data = if data.offset() != 0 || data.len() != cd.len() { + Arc::new(cd.as_ref().slice(data.offset(), data.len())) } else { cd.clone() }; diff --git a/rust/arrow/src/array/data.rs b/rust/arrow/src/array/data.rs index a1426a6fb88..f7ad3bb919d 100644 --- a/rust/arrow/src/array/data.rs +++ b/rust/arrow/src/array/data.rs @@ -26,6 +26,15 @@ use crate::buffer::Buffer; use crate::datatypes::DataType; use crate::util::bit_util; +fn count_nulls(null_bit_buffer: Option<&Buffer>, offset: usize, len: usize) -> usize { + if let Some(ref buf) = null_bit_buffer { + len.checked_sub(bit_util::count_set_bits_offset(buf.data(), offset, len)) + .unwrap() + } else { + 0 + } +} + /// An generic representation of Arrow array data which encapsulates common attributes and /// operations for Arrow array. Specific operations for different arrays types (e.g., /// primitive, list, struct) are implemented in `Array`. @@ -35,13 +44,13 @@ pub struct ArrayData { data_type: DataType, /// The number of elements in this array data - pub(crate) len: usize, + len: usize, /// The number of null elements in this array data - pub(crate) null_count: usize, + null_count: usize, /// The offset into this array data - pub(crate) offset: usize, + offset: usize, /// The buffers for this array data. Note that depending on the array types, this /// could hold different kinds of buffers (e.g., value buffer, value offset buffer) @@ -70,18 +79,7 @@ impl ArrayData { child_data: Vec, ) -> Self { let null_count = match null_count { - None => { - if let Some(ref buf) = null_bit_buffer { - len.checked_sub(bit_util::count_set_bits_offset( - buf.data(), - offset, - len, - )) - .unwrap() - } else { - 0 - } - } + None => count_nulls(null_bit_buffer.as_ref(), offset, len), Some(null_count) => null_count, }; let null_bitmap = null_bit_buffer.map(Bitmap::from); @@ -207,6 +205,26 @@ impl ArrayData { size } + + /// Creates a zero-copy slice of itself. This creates a new [ArrayData] + /// with a different offset, len and a shifted null bitmap. + /// + /// # Panics + /// + /// Panics if `offset + length > self.len()`. + pub fn slice(&self, offset: usize, length: usize) -> ArrayData { + assert!((offset + length) <= self.len()); + + let mut new_data = self.clone(); + + new_data.len = length; + new_data.offset = offset + self.offset; + + new_data.null_count = + count_nulls(new_data.null_buffer(), new_data.offset, new_data.len); + + new_data + } } impl PartialEq for ArrayData { @@ -435,4 +453,27 @@ mod tests { assert!(arr_data.null_buffer().is_some()); assert_eq!(&bit_v, arr_data.null_buffer().unwrap().data()); } + + #[test] + fn test_slice() { + let mut bit_v: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut bit_v, 0); + bit_util::set_bit(&mut bit_v, 3); + bit_util::set_bit(&mut bit_v, 10); + let data = ArrayData::builder(DataType::Int32) + .len(16) + .null_bit_buffer(Buffer::from(bit_v)) + .build(); + let data = data.as_ref(); + let new_data = data.slice(1, 15); + assert_eq!(data.len() - 1, new_data.len()); + assert_eq!(1, new_data.offset()); + assert_eq!(data.null_count(), new_data.null_count()); + + // slice of a slice (removes one null) + let new_data = new_data.slice(1, 14); + assert_eq!(data.len() - 2, new_data.len()); + assert_eq!(2, new_data.offset()); + assert_eq!(data.null_count() - 1, new_data.null_count()); + } } diff --git a/rust/arrow/src/array/equal.rs b/rust/arrow/src/array/equal.rs index 4c549286bff..a747b83a11b 100644 --- a/rust/arrow/src/array/equal.rs +++ b/rust/arrow/src/array/equal.rs @@ -779,13 +779,13 @@ fn base_equal(this: &ArrayDataRef, other: &ArrayDataRef) -> bool { if this.data_type() != other.data_type() { return false; } - if this.len != other.len { + if this.len() != other.len() { return false; } - if this.null_count != other.null_count { + if this.null_count() != other.null_count() { return false; } - if this.null_count > 0 { + if this.null_count() > 0 { let null_bitmap = this.null_bitmap().as_ref().unwrap(); let other_null_bitmap = other.null_bitmap().as_ref().unwrap(); let null_buf = null_bitmap.bits.data(); From 5801f5c06a4feffe2b876936d10ce123eabc9779 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Mon, 26 Oct 2020 04:17:44 +0100 Subject: [PATCH 4/4] Refactored equality. --- rust/arrow/benches/equal.rs | 14 +- rust/arrow/src/array/array.rs | 45 +- rust/arrow/src/array/builder.rs | 47 +- rust/arrow/src/array/data.rs | 24 +- rust/arrow/src/array/equal.rs | 1237 ------------------- rust/arrow/src/array/equal/boolean.rs | 49 + rust/arrow/src/array/equal/dictionary.rs | 67 + rust/arrow/src/array/equal/fixed_binary.rs | 65 + rust/arrow/src/array/equal/fixed_list.rs | 65 + rust/arrow/src/array/equal/list.rs | 117 ++ rust/arrow/src/array/equal/mod.rs | 831 +++++++++++++ rust/arrow/src/array/equal/null.rs | 31 + rust/arrow/src/array/equal/primitive.rs | 63 + rust/arrow/src/array/equal/structure.rs | 59 + rust/arrow/src/array/equal/utils.rs | 76 ++ rust/arrow/src/array/equal/variable_size.rs | 91 ++ rust/arrow/src/array/mod.rs | 1 - rust/arrow/src/compute/kernels/concat.rs | 30 +- rust/arrow/src/compute/kernels/sort.rs | 46 +- rust/arrow/src/compute/kernels/take.rs | 43 +- 20 files changed, 1594 insertions(+), 1407 deletions(-) delete mode 100644 rust/arrow/src/array/equal.rs create mode 100644 rust/arrow/src/array/equal/boolean.rs create mode 100644 rust/arrow/src/array/equal/dictionary.rs create mode 100644 rust/arrow/src/array/equal/fixed_binary.rs create mode 100644 rust/arrow/src/array/equal/fixed_list.rs create mode 100644 rust/arrow/src/array/equal/list.rs create mode 100644 rust/arrow/src/array/equal/mod.rs create mode 100644 rust/arrow/src/array/equal/null.rs create mode 100644 rust/arrow/src/array/equal/primitive.rs create mode 100644 rust/arrow/src/array/equal/structure.rs create mode 100644 rust/arrow/src/array/equal/utils.rs create mode 100644 rust/arrow/src/array/equal/variable_size.rs diff --git a/rust/arrow/benches/equal.rs b/rust/arrow/benches/equal.rs index b70b633ff6d..a73b70e1011 100644 --- a/rust/arrow/benches/equal.rs +++ b/rust/arrow/benches/equal.rs @@ -62,13 +62,7 @@ fn create_array(size: usize, with_nulls: bool) -> ArrayRef { } fn bench_equal(arr_a: &ArrayRef) { - let arr_a = arr_a.as_any().downcast_ref::().unwrap(); - criterion::black_box(arr_a.equals(arr_a)); -} - -fn bench_equal_string(arr_a: &ArrayRef) { - let arr_a = arr_a.as_any().downcast_ref::().unwrap(); - criterion::black_box(arr_a.equals(arr_a)); + criterion::black_box(arr_a == arr_a); } fn add_benchmark(c: &mut Criterion) { @@ -79,13 +73,11 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("equal_nulls_512", |b| b.iter(|| bench_equal(&arr_a_nulls))); let arr_a = create_string_array(512, false); - c.bench_function("equal_string_512", |b| { - b.iter(|| bench_equal_string(&arr_a)) - }); + c.bench_function("equal_string_512", |b| b.iter(|| bench_equal(&arr_a))); let arr_a_nulls = create_string_array(512, true); c.bench_function("equal_string_nulls_512", |b| { - b.iter(|| bench_equal_string(&arr_a_nulls)) + b.iter(|| bench_equal(&arr_a_nulls)) }); } diff --git a/rust/arrow/src/array/array.rs b/rust/arrow/src/array/array.rs index 3a202c0b5c4..b7e13a03ad7 100644 --- a/rust/arrow/src/array/array.rs +++ b/rust/arrow/src/array/array.rs @@ -50,7 +50,7 @@ const NANOSECONDS: i64 = 1_000_000_000; /// Trait for dealing with different types of array at runtime when the type of the /// array is not known in advance. -pub trait Array: fmt::Debug + Send + Sync + ArrayEqual + JsonEqual { +pub trait Array: fmt::Debug + Send + Sync + JsonEqual { /// Returns the array as [`Any`](std::any::Any) so that it can be /// downcasted to a specific implementation. /// @@ -112,7 +112,7 @@ pub trait Array: fmt::Debug + Send + Sync + ArrayEqual + JsonEqual { /// // Make slice over the values [2, 3, 4] /// let array_slice = array.slice(1, 3); /// - /// assert!(array_slice.equals(&Int32Array::from(vec![2, 3, 4]))); + /// assert_eq!(array_slice.as_ref(), &Int32Array::from(vec![2, 3, 4])); /// ``` fn slice(&self, offset: usize, length: usize) -> ArrayRef { make_array(Arc::new(self.data_ref().as_ref().slice(offset, length))) @@ -182,8 +182,7 @@ pub trait Array: fmt::Debug + Send + Sync + ArrayEqual + JsonEqual { /// assert_eq!(array.is_null(1), true); /// ``` fn is_null(&self, index: usize) -> bool { - let data = self.data_ref(); - data.is_null(data.offset() + index) + self.data().is_null(index) } /// Returns whether the element at `index` is not null. @@ -200,8 +199,7 @@ pub trait Array: fmt::Debug + Send + Sync + ArrayEqual + JsonEqual { /// assert_eq!(array.is_valid(1), false); /// ``` fn is_valid(&self, index: usize) -> bool { - let data = self.data_ref(); - data.is_valid(data.offset() + index) + self.data().is_valid(index) } /// Returns the total number of null values in this array. @@ -826,11 +824,6 @@ impl From for PrimitiveArray { } } -/// Common operations for List types. -pub trait ListArrayOps { - fn value_offset_at(&self, i: usize) -> OffsetSize; -} - /// trait declaring an offset size, relevant for i32 vs i64 array types. pub trait OffsetSizeTrait: ArrowNativeType + Num + Ord { fn prefix() -> &'static str; @@ -1006,14 +999,6 @@ impl fmt::Debug for GenericListArray { } } -impl ListArrayOps - for GenericListArray -{ - fn value_offset_at(&self, i: usize) -> OffsetSize { - self.value_offset_at(i) - } -} - /// A list array where each element is a variable-sized sequence of values with the same /// type whose memory offsets between elements are represented by a i32. pub type ListArray = GenericListArray; @@ -1300,14 +1285,6 @@ impl Array for GenericBinaryArray } } -impl ListArrayOps - for GenericBinaryArray -{ - fn value_offset_at(&self, i: usize) -> OffsetSize { - self.value_offset_at(i) - } -} - impl From for GenericBinaryArray { @@ -1664,14 +1641,6 @@ impl From } } -impl ListArrayOps - for GenericStringArray -{ - fn value_offset_at(&self, i: usize) -> OffsetSize { - self.value_offset_at(i) - } -} - /// An array where each element is a variable-sized sequence of bytes representing a string /// whose maximum length (in bytes) is represented by a i32. pub type StringArray = GenericStringArray; @@ -1767,12 +1736,6 @@ impl FixedSizeBinaryArray { } } -impl ListArrayOps for FixedSizeBinaryArray { - fn value_offset_at(&self, i: usize) -> i32 { - self.value_offset_at(i) - } -} - impl From for FixedSizeBinaryArray { fn from(data: ArrayDataRef) -> Self { assert_eq!( diff --git a/rust/arrow/src/array/builder.rs b/rust/arrow/src/array/builder.rs index 21ddb7dfe29..f1037c735fb 100644 --- a/rust/arrow/src/array/builder.rs +++ b/rust/arrow/src/array/builder.rs @@ -531,7 +531,7 @@ impl ArrayBuilder for PrimitiveBuilder { for i in 0..len { // account for offset as `ArrayData` does not - self.bitmap_builder.append(array.is_valid(offset + i))?; + self.bitmap_builder.append(array.is_valid(i))?; } } Ok(()) @@ -761,8 +761,7 @@ where .append_slice(adjusted_offsets.as_slice())?; for i in 0..len { - // account for offset as `ArrayData` does not - self.bitmap_builder.append(array.is_valid(offset + i))?; + self.bitmap_builder.append(array.is_valid(i))?; } } @@ -974,8 +973,7 @@ where .append_slice(adjusted_offsets.as_slice())?; for i in 0..len { - // account for offset as `ArrayData` does not - self.bitmap_builder.append(array.is_valid(offset + i))?; + self.bitmap_builder.append(array.is_valid(i))?; } } @@ -1156,8 +1154,7 @@ where let sliced = child_array.slice(first_offset, offset_at_len - first_offset); self.values().append_data(&[sliced.data()])?; for i in 0..len { - // account for offset as `ArrayData` does not - self.bitmap_builder.append(array.is_valid(offset + i))?; + self.bitmap_builder.append(array.is_valid(i))?; } } @@ -1963,8 +1960,7 @@ impl ArrayBuilder for StructBuilder { builder.append_data(&[sliced.data()])?; } for i in 0..len { - // account for offset as `ArrayData` does not - self.bitmap_builder.append(array.is_valid(offset + i))?; + self.bitmap_builder.append(array.is_valid(i))?; } } @@ -3558,7 +3554,7 @@ mod tests { array.slice(2, 0).data(), ])?; let finished = builder.finish(); - let expected = Arc::new(Int32Array::from(vec![ + let expected = Int32Array::from(vec![ None, Some(1), None, @@ -3567,14 +3563,15 @@ mod tests { None, Some(6), Some(7), + // array.data() end Some(3), None, None, Some(6), - ])) as ArrayRef; + ]); assert_eq!(finished.len(), expected.len()); assert_eq!(finished.null_count(), expected.null_count()); - assert!(finished.equals(&(*expected))); + assert_eq!(finished, expected); let mut builder = Float64Builder::new(64); builder.append_null()?; @@ -3588,7 +3585,7 @@ mod tests { array.slice(2, 1).data(), ])?; let finished = builder.finish(); - let expected = Arc::new(Float64Array::from(vec![ + let expected = Float64Array::from(vec![ None, Some(1.0), None, @@ -3603,10 +3600,10 @@ mod tests { Some(6.0), Some(7.0), None, - ])) as ArrayRef; + ]); assert_eq!(finished.len(), expected.len()); assert_eq!(finished.null_count(), expected.null_count()); - assert!(finished.equals(&(*expected))); + assert_eq!(finished, expected); Ok(()) } @@ -3630,7 +3627,7 @@ mod tests { array.slice(2, 0).data(), ])?; let finished = builder.finish(); - let expected = Arc::new(BooleanArray::from(vec![ + let expected = BooleanArray::from(vec![ None, Some(true), None, @@ -3643,10 +3640,10 @@ mod tests { None, None, Some(false), - ])) as ArrayRef; + ]); assert_eq!(finished.len(), expected.len()); assert_eq!(finished.null_count(), expected.null_count()); - assert!(finished.equals(&(*expected))); + assert_eq!(finished, expected); Ok(()) } @@ -3712,7 +3709,7 @@ mod tests { finished.data().buffers()[0].data(), expected_list.data().buffers()[0].data() ); - assert!(expected_list.values().equals(&*finished.values())); + assert_eq!(&expected_list.values(), &finished.values()); assert_eq!(expected_list.len(), finished.len()); Ok(()) @@ -3802,7 +3799,7 @@ mod tests { finished.data().child_data()[0].buffers()[0].data(), expected_list.data().child_data()[0].buffers()[0].data() ); - assert!(expected_list.values().equals(&*finished.values())); + assert_eq!(&expected_list.values(), &finished.values()); assert_eq!(expected_list.len(), finished.len()); Ok(()) @@ -3879,7 +3876,7 @@ mod tests { finished.data().child_data()[0].buffers()[0].data(), expected_list.data().child_data()[0].buffers()[0].data() ); - assert!(expected_list.values().equals(&*finished.values())); + assert_eq!(&expected_list.values(), &finished.values()); assert_eq!(expected_list.len(), finished.len()); Ok(()) @@ -3963,7 +3960,7 @@ mod tests { ); let expected_list = FixedSizeListArray::from(Arc::new(expected_list_data) as ArrayDataRef); - assert!(expected_list.values().equals(&*finished.values())); + assert_eq!(&expected_list.values(), &finished.values()); assert_eq!(expected_list.len(), finished.len()); Ok(()) @@ -4037,7 +4034,7 @@ mod tests { let expected_list = FixedSizeListArray::from(Arc::new(expected_list_data) as ArrayDataRef); let expected_list = FixedSizeBinaryArray::from(expected_list); - // assert!(expected_list.values().equals(&*finished.values())); + // assert_eq!(expected_list.values(), finished.values()); assert_eq!(expected_list.len(), finished.len()); Ok(()) @@ -4111,10 +4108,10 @@ mod tests { true, true, true, false, true, false, true, false, true, false, true, false, true, false, ])) as ArrayRef; - let expected = Arc::new(StructArray::from(vec![(field1, f1), (field2, f2)])); + let expected = StructArray::from(vec![(field1, f1), (field2, f2)]); assert_eq!(arr2.data().child_data()[0], expected.data().child_data()[0]); assert_eq!(arr2.data().child_data()[1], expected.data().child_data()[1]); - assert!(arr2.equals(&*expected)); + assert_eq!(arr2, expected); Ok(()) } diff --git a/rust/arrow/src/array/data.rs b/rust/arrow/src/array/data.rs index f7ad3bb919d..e328c01ee7a 100644 --- a/rust/arrow/src/array/data.rs +++ b/rust/arrow/src/array/data.rs @@ -21,11 +21,12 @@ use std::mem; use std::sync::Arc; -use crate::bitmap::Bitmap; use crate::buffer::Buffer; use crate::datatypes::DataType; use crate::util::bit_util; +use crate::{bitmap::Bitmap, datatypes::ArrowNativeType}; +#[inline] fn count_nulls(null_bit_buffer: Option<&Buffer>, offset: usize, len: usize) -> usize { if let Some(ref buf) = null_bit_buffer { len.checked_sub(bit_util::count_set_bits_offset(buf.data(), offset, len)) @@ -49,7 +50,7 @@ pub struct ArrayData { /// The number of null elements in this array data null_count: usize, - /// The offset into this array data + /// The offset into this array data, in number of items offset: usize, /// The buffers for this array data. Note that depending on the array types, this @@ -119,7 +120,7 @@ impl ArrayData { /// Returns whether the element at index `i` is null pub fn is_null(&self, i: usize) -> bool { if let Some(ref b) = self.null_bitmap { - return !b.is_set(i); + return !b.is_set(self.offset + i); } false } @@ -138,7 +139,7 @@ impl ArrayData { /// Returns whether the element at index `i` is not null pub fn is_valid(&self, i: usize) -> bool { if let Some(ref b) = self.null_bitmap { - return b.is_set(i); + return b.is_set(self.offset + i); } true } @@ -225,6 +226,21 @@ impl ArrayData { new_data } + + /// Returns the `buffer` as a slice of type `T` starting at self.offset + /// # Panics + /// This function panics if: + /// * the buffer is not byte-aligned with type T, or + /// * the datatype is `Boolean` (it corresponds to a bit-packed buffer where the offset is not applicable) + #[inline] + pub(super) fn buffer(&self, buffer: usize) -> &[T] { + let values = unsafe { self.buffers[buffer].data().align_to::() }; + if values.0.len() != 0 || values.2.len() != 0 { + panic!("The buffer is not byte-aligned with its interpretation") + }; + assert_ne!(self.data_type, DataType::Boolean); + &values.1[self.offset..] + } } impl PartialEq for ArrayData { diff --git a/rust/arrow/src/array/equal.rs b/rust/arrow/src/array/equal.rs deleted file mode 100644 index a747b83a11b..00000000000 --- a/rust/arrow/src/array/equal.rs +++ /dev/null @@ -1,1237 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; -use crate::datatypes::*; -use crate::util::bit_util; -use array::{ - Array, BinaryOffsetSizeTrait, GenericBinaryArray, GenericListArray, - GenericStringArray, ListArrayOps, OffsetSizeTrait, StringOffsetSizeTrait, -}; - -/// Trait for `Array` equality. -pub trait ArrayEqual { - /// Returns true if this array is equal to the `other` array - fn equals(&self, other: &dyn Array) -> bool; - - /// Returns true if the range [start_idx, end_idx) is equal to - /// [other_start_idx, other_start_idx + end_idx - start_idx) in the `other` array - fn range_equals( - &self, - other: &dyn Array, - start_idx: usize, - end_idx: usize, - other_start_idx: usize, - ) -> bool; -} - -impl ArrayEqual for PrimitiveArray { - fn equals(&self, other: &dyn Array) -> bool { - if !base_equal(&self.data(), &other.data()) { - return false; - } - - if T::DATA_TYPE == DataType::Boolean { - return bool_equal(self, other); - } - - let value_buf = self.data_ref().buffers()[0].clone(); - let other_value_buf = other.data_ref().buffers()[0].clone(); - let byte_width = T::get_bit_width() / 8; - - if self.null_count() > 0 { - let values = value_buf.data(); - let other_values = other_value_buf.data(); - - for i in 0..self.len() { - if self.is_valid(i) { - let start = (i + self.offset()) * byte_width; - let data = &values[start..(start + byte_width)]; - let other_start = (i + other.offset()) * byte_width; - let other_data = - &other_values[other_start..(other_start + byte_width)]; - if data != other_data { - return false; - } - } - } - } else { - let start = self.offset() * byte_width; - let other_start = other.offset() * byte_width; - let len = self.len() * byte_width; - let data = &value_buf.data()[start..(start + len)]; - let other_data = &other_value_buf.data()[other_start..(other_start + len)]; - if data != other_data { - return false; - } - } - - true - } - - fn range_equals( - &self, - other: &dyn Array, - start_idx: usize, - end_idx: usize, - other_start_idx: usize, - ) -> bool { - assert!(other_start_idx + (end_idx - start_idx) <= other.len()); - let other = other.as_any().downcast_ref::>().unwrap(); - - let mut j = other_start_idx; - for i in start_idx..end_idx { - let is_null = self.is_null(i); - let other_is_null = other.is_null(j); - if is_null != other_is_null || (!is_null && self.value(i) != other.value(j)) { - return false; - } - j += 1; - } - - true - } -} - -fn bool_equal(lhs: &Array, rhs: &Array) -> bool { - let values = lhs.data_ref().buffers()[0].data(); - let other_values = rhs.data_ref().buffers()[0].data(); - - // TODO: we can do this more efficiently if all values are not-null - for i in 0..lhs.len() { - if lhs.is_valid(i) - && bit_util::get_bit(values, i + lhs.offset()) - != bit_util::get_bit(other_values, i + rhs.offset()) - { - return false; - } - } - true -} - -impl PartialEq for PrimitiveArray { - fn eq(&self, other: &PrimitiveArray) -> bool { - self.equals(other) - } -} - -impl PartialEq for BooleanArray { - fn eq(&self, other: &BooleanArray) -> bool { - self.equals(other) - } -} - -impl PartialEq for GenericStringArray { - fn eq(&self, other: &Self) -> bool { - self.equals(other) - } -} - -impl PartialEq for GenericBinaryArray { - fn eq(&self, other: &Self) -> bool { - self.equals(other) - } -} - -impl PartialEq for FixedSizeBinaryArray { - fn eq(&self, other: &Self) -> bool { - self.equals(other) - } -} - -impl ArrayEqual for GenericListArray { - fn equals(&self, other: &dyn Array) -> bool { - if !base_equal(&self.data(), &other.data()) { - return false; - } - - let other = other - .as_any() - .downcast_ref::>() - .unwrap(); - - if !value_offset_equal(self, other) { - return false; - } - - if !self.values().range_equals( - &*other.values(), - self.value_offset(0).to_usize().unwrap(), - self.value_offset(self.len()).to_usize().unwrap(), - other.value_offset(0).to_usize().unwrap(), - ) { - return false; - } - - true - } - - fn range_equals( - &self, - other: &dyn Array, - start_idx: usize, - end_idx: usize, - other_start_idx: usize, - ) -> bool { - assert!(other_start_idx + (end_idx - start_idx) <= other.len()); - - let other = other - .as_any() - .downcast_ref::>() - .unwrap(); - - let mut j = other_start_idx; - for i in start_idx..end_idx { - let is_null = self.is_null(i); - let other_is_null = other.is_null(j); - - if is_null != other_is_null { - return false; - } - - if is_null { - continue; - } - - let start_offset = self.value_offset(i).to_usize().unwrap(); - let end_offset = self.value_offset(i + 1).to_usize().unwrap(); - let other_start_offset = other.value_offset(j).to_usize().unwrap(); - let other_end_offset = other.value_offset(j + 1).to_usize().unwrap(); - - if end_offset - start_offset != other_end_offset - other_start_offset { - return false; - } - - if !self.values().range_equals( - other, - start_offset, - end_offset, - other_start_offset, - ) { - return false; - } - - j += 1; - } - - true - } -} - -impl ArrayEqual for DictionaryArray { - fn equals(&self, other: &dyn Array) -> bool { - self.range_equals(other, 0, self.len(), 0) - } - - fn range_equals( - &self, - other: &dyn Array, - start_idx: usize, - end_idx: usize, - other_start_idx: usize, - ) -> bool { - assert!(other_start_idx + (end_idx - start_idx) <= other.len()); - let other = other.as_any().downcast_ref::>().unwrap(); - - // For now, all the values must be the same - self.keys() - .range_equals(other.keys(), start_idx, end_idx, other_start_idx) - && self - .values() - .range_equals(&*other.values(), 0, other.values().len(), 0) - } -} - -impl ArrayEqual for FixedSizeListArray { - fn equals(&self, other: &dyn Array) -> bool { - if !base_equal(&self.data(), &other.data()) { - return false; - } - - let other = other.as_any().downcast_ref::().unwrap(); - - if !self.values().range_equals( - &*other.values(), - self.value_offset(0) as usize, - self.value_offset(self.len()) as usize, - other.value_offset(0) as usize, - ) { - return false; - } - - true - } - - fn range_equals( - &self, - other: &dyn Array, - start_idx: usize, - end_idx: usize, - other_start_idx: usize, - ) -> bool { - assert!(other_start_idx + (end_idx - start_idx) <= other.len()); - let other = other.as_any().downcast_ref::().unwrap(); - - let mut j = other_start_idx; - for i in start_idx..end_idx { - let is_null = self.is_null(i); - let other_is_null = other.is_null(j); - - if is_null != other_is_null { - return false; - } - - if is_null { - continue; - } - - let start_offset = self.value_offset(i) as usize; - let end_offset = self.value_offset(i + 1) as usize; - let other_start_offset = other.value_offset(j) as usize; - let other_end_offset = other.value_offset(j + 1) as usize; - - if end_offset - start_offset != other_end_offset - other_start_offset { - return false; - } - - if !self.values().range_equals( - &*other.values(), - start_offset, - end_offset, - other_start_offset, - ) { - return false; - } - - j += 1; - } - - true - } -} - -impl ArrayEqual for GenericBinaryArray { - fn equals(&self, other: &dyn Array) -> bool { - if !base_equal(&self.data(), &other.data()) { - return false; - } - - let other = other - .as_any() - .downcast_ref::>() - .unwrap(); - - if !value_offset_equal(self, other) { - return false; - } - - // TODO: handle null & length == 0 case? - - let value_buf = self.value_data(); - let other_value_buf = other.value_data(); - let value_data = value_buf.data(); - let other_value_data = other_value_buf.data(); - - if self.null_count() == 0 { - // No offset in both - just do memcmp - if self.offset() == 0 && other.offset() == 0 { - let len = self.value_offset(self.len()).to_usize().unwrap(); - return value_data[..len] == other_value_data[..len]; - } else { - let start = self.value_offset(0).to_usize().unwrap(); - let other_start = other.value_offset(0).to_usize().unwrap(); - let len = (self.value_offset(self.len()) - self.value_offset(0)) - .to_usize() - .unwrap(); - return value_data[start..(start + len)] - == other_value_data[other_start..(other_start + len)]; - } - } else { - for i in 0..self.len() { - if self.is_null(i) { - continue; - } - - let start = self.value_offset(i).to_usize().unwrap(); - let other_start = other.value_offset(i).to_usize().unwrap(); - let len = self.value_length(i).to_usize().unwrap(); - if value_data[start..(start + len)] - != other_value_data[other_start..(other_start + len)] - { - return false; - } - } - } - - true - } - - fn range_equals( - &self, - other: &dyn Array, - start_idx: usize, - end_idx: usize, - other_start_idx: usize, - ) -> bool { - assert!(other_start_idx + (end_idx - start_idx) <= other.len()); - let other = other - .as_any() - .downcast_ref::>() - .unwrap(); - - let mut j = other_start_idx; - for i in start_idx..end_idx { - let is_null = self.is_null(i); - let other_is_null = other.is_null(j); - - if is_null != other_is_null { - return false; - } - - if is_null { - continue; - } - - let start_offset = self.value_offset(i).to_usize().unwrap(); - let end_offset = self.value_offset(i + 1).to_usize().unwrap(); - let other_start_offset = other.value_offset(j).to_usize().unwrap(); - let other_end_offset = other.value_offset(j + 1).to_usize().unwrap(); - - if end_offset - start_offset != other_end_offset - other_start_offset { - return false; - } - - let value_buf = self.value_data(); - let other_value_buf = other.value_data(); - let value_data = value_buf.data(); - let other_value_data = other_value_buf.data(); - - if end_offset - start_offset > 0 { - let len = end_offset - start_offset; - if value_data[start_offset..(start_offset + len)] - != other_value_data[other_start_offset..(other_start_offset + len)] - { - return false; - } - } - - j += 1; - } - - true - } -} - -impl ArrayEqual for GenericStringArray { - fn equals(&self, other: &dyn Array) -> bool { - if !base_equal(&self.data(), &other.data()) { - return false; - } - - let other = other - .as_any() - .downcast_ref::>() - .unwrap(); - - if !value_offset_equal(self, other) { - return false; - } - - // TODO: handle null & length == 0 case? - - let value_buf = self.value_data(); - let other_value_buf = other.value_data(); - let value_data = value_buf.data(); - let other_value_data = other_value_buf.data(); - - if self.null_count() == 0 { - // No offset in both - just do memcmp - if self.offset() == 0 && other.offset() == 0 { - let len = self.value_offset(self.len()).to_usize().unwrap(); - return value_data[..len] == other_value_data[..len]; - } else { - let start = self.value_offset(0).to_usize().unwrap(); - let other_start = other.value_offset(0).to_usize().unwrap(); - let len = (self.value_offset(self.len()) - self.value_offset(0)) - .to_usize() - .unwrap(); - return value_data[start..(start + len)] - == other_value_data[other_start..(other_start + len)]; - } - } else { - for i in 0..self.len() { - if self.is_null(i) { - continue; - } - - let start = self.value_offset(i).to_usize().unwrap(); - let other_start = other.value_offset(i).to_usize().unwrap(); - let len = self.value_length(i).to_usize().unwrap(); - if value_data[start..(start + len)] - != other_value_data[other_start..(other_start + len)] - { - return false; - } - } - } - - true - } - - fn range_equals( - &self, - other: &dyn Array, - start_idx: usize, - end_idx: usize, - other_start_idx: usize, - ) -> bool { - assert!(other_start_idx + (end_idx - start_idx) <= other.len()); - let other = other - .as_any() - .downcast_ref::>() - .unwrap(); - - let mut j = other_start_idx; - for i in start_idx..end_idx { - let is_null = self.is_null(i); - let other_is_null = other.is_null(j); - - if is_null != other_is_null { - return false; - } - - if is_null { - continue; - } - - let start_offset = self.value_offset(i).to_usize().unwrap(); - let end_offset = self.value_offset(i + 1).to_usize().unwrap(); - let other_start_offset = other.value_offset(j).to_usize().unwrap(); - let other_end_offset = other.value_offset(j + 1).to_usize().unwrap(); - - if end_offset - start_offset != other_end_offset - other_start_offset { - return false; - } - - let value_buf = self.value_data(); - let other_value_buf = other.value_data(); - let value_data = value_buf.data(); - let other_value_data = other_value_buf.data(); - - if end_offset - start_offset > 0 { - let len = end_offset - start_offset; - if value_data[start_offset..(start_offset + len)] - != other_value_data[other_start_offset..(other_start_offset + len)] - { - return false; - } - } - - j += 1; - } - - true - } -} - -impl ArrayEqual for FixedSizeBinaryArray { - fn equals(&self, other: &dyn Array) -> bool { - if !base_equal(&self.data(), &other.data()) { - return false; - } - - let other = other - .as_any() - .downcast_ref::() - .unwrap(); - - let this = self - .as_any() - .downcast_ref::() - .unwrap(); - - if !value_offset_equal(this, other) { - return false; - } - - // TODO: handle null & length == 0 case? - - let value_buf = self.value_data(); - let other_value_buf = other.value_data(); - let value_data = value_buf.data(); - let other_value_data = other_value_buf.data(); - - if self.null_count() == 0 { - // No offset in both - just do memcmp - if self.offset() == 0 && other.offset() == 0 { - let len = self.value_offset(self.len()) as usize; - return value_data[..len] == other_value_data[..len]; - } else { - let start = self.value_offset(0) as usize; - let other_start = other.value_offset(0) as usize; - let len = (self.value_offset(self.len()) - self.value_offset(0)) as usize; - return value_data[start..(start + len)] - == other_value_data[other_start..(other_start + len)]; - } - } else { - for i in 0..self.len() { - if self.is_null(i) { - continue; - } - - let start = self.value_offset(i) as usize; - let other_start = other.value_offset(i) as usize; - let len = self.value_length() as usize; - if value_data[start..(start + len)] - != other_value_data[other_start..(other_start + len)] - { - return false; - } - } - } - - true - } - - fn range_equals( - &self, - other: &dyn Array, - start_idx: usize, - end_idx: usize, - other_start_idx: usize, - ) -> bool { - assert!(other_start_idx + (end_idx - start_idx) <= other.len()); - let other = other - .as_any() - .downcast_ref::() - .unwrap(); - - let mut j = other_start_idx; - for i in start_idx..end_idx { - let is_null = self.is_null(i); - let other_is_null = other.is_null(j); - - if is_null != other_is_null { - return false; - } - - if is_null { - continue; - } - - let start_offset = self.value_offset(i) as usize; - let end_offset = self.value_offset(i + 1) as usize; - let other_start_offset = other.value_offset(j) as usize; - let other_end_offset = other.value_offset(j + 1) as usize; - - if end_offset - start_offset != other_end_offset - other_start_offset { - return false; - } - - let value_buf = self.value_data(); - let other_value_buf = other.value_data(); - let value_data = value_buf.data(); - let other_value_data = other_value_buf.data(); - - if end_offset - start_offset > 0 { - let len = end_offset - start_offset; - if value_data[start_offset..(start_offset + len)] - != other_value_data[other_start_offset..(other_start_offset + len)] - { - return false; - } - } - - j += 1; - } - - true - } -} - -impl ArrayEqual for StructArray { - fn equals(&self, other: &dyn Array) -> bool { - if !base_equal(&self.data(), &other.data()) { - return false; - } - - let other = other.as_any().downcast_ref::().unwrap(); - - for i in 0..self.len() { - let is_null = self.is_null(i); - let other_is_null = other.is_null(i); - - if is_null != other_is_null { - return false; - } - - if is_null { - continue; - } - for j in 0..self.num_columns() { - if !self.column(j).range_equals(&**other.column(j), i, i + 1, i) { - return false; - } - } - } - - true - } - - fn range_equals( - &self, - other: &dyn Array, - start_idx: usize, - end_idx: usize, - other_start_idx: usize, - ) -> bool { - assert!(other_start_idx + (end_idx - start_idx) <= other.len()); - let other = other.as_any().downcast_ref::().unwrap(); - - let mut j = other_start_idx; - for i in start_idx..end_idx { - let is_null = self.is_null(i); - let other_is_null = other.is_null(i); - - if is_null != other_is_null { - return false; - } - - if is_null { - continue; - } - for k in 0..self.num_columns() { - if !self.column(k).range_equals(&**other.column(k), i, i + 1, j) { - return false; - } - } - - j += 1; - } - - true - } -} - -impl ArrayEqual for UnionArray { - fn equals(&self, _other: &dyn Array) -> bool { - unimplemented!( - "Added to allow UnionArray to implement the Array trait: see ARROW-8576" - ) - } - - fn range_equals( - &self, - _other: &dyn Array, - _start_idx: usize, - _end_idx: usize, - _other_start_idx: usize, - ) -> bool { - unimplemented!( - "Added to allow UnionArray to implement the Array trait: see ARROW-8576" - ) - } -} - -impl ArrayEqual for NullArray { - fn equals(&self, other: &dyn Array) -> bool { - if other.data_type() != &DataType::Null { - return false; - } - - if self.len() != other.len() { - return false; - } - if self.null_count() != other.null_count() { - return false; - } - - true - } - - fn range_equals( - &self, - _other: &dyn Array, - _start_idx: usize, - _end_idx: usize, - _other_start_idx: usize, - ) -> bool { - unimplemented!("Range comparison for null array not yet supported") - } -} - -// Compare if the common basic fields between the two arrays are equal -fn base_equal(this: &ArrayDataRef, other: &ArrayDataRef) -> bool { - if this.data_type() != other.data_type() { - return false; - } - if this.len() != other.len() { - return false; - } - if this.null_count() != other.null_count() { - return false; - } - if this.null_count() > 0 { - let null_bitmap = this.null_bitmap().as_ref().unwrap(); - let other_null_bitmap = other.null_bitmap().as_ref().unwrap(); - let null_buf = null_bitmap.bits.data(); - let other_null_buf = other_null_bitmap.bits.data(); - for i in 0..this.len() { - if bit_util::get_bit(null_buf, i + this.offset()) - != bit_util::get_bit(other_null_buf, i + other.offset()) - { - return false; - } - } - } - true -} - -// Compare if the value offsets are equal between the two list arrays -fn value_offset_equal>( - this: &T, - other: &T, -) -> bool { - // Check if offsets differ - if this.offset() == 0 && other.offset() == 0 { - let offset_data = &this.data_ref().buffers()[0]; - let other_offset_data = &other.data_ref().buffers()[0]; - return offset_data.data()[0..((this.len() + 1) * 4)] - == other_offset_data.data()[0..((other.len() + 1) * 4)]; - } - - // The expensive case - for i in 0..=this.len() { - if this.value_offset_at(i) - this.value_offset_at(0) - != other.value_offset_at(i) - other.value_offset_at(0) - { - return false; - } - } - - true -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::error::Result; - use std::{convert::TryFrom, sync::Arc}; - - #[test] - fn test_primitive_equal() { - let a = Int32Array::from(vec![1, 2, 3]); - let b = Int32Array::from(vec![1, 2, 3]); - assert!(a.equals(&b)); - assert!(b.equals(&a)); - - let b = Int32Array::from(vec![1, 2, 4]); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - // Test the case where null_count > 0 - - let a = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); - let b = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); - assert!(a.equals(&b)); - assert!(b.equals(&a)); - - let b = Int32Array::from(vec![Some(1), None, None, Some(3)]); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - let b = Int32Array::from(vec![Some(1), None, Some(2), Some(4)]); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - // Test the case where offset != 0 - - let a_slice = a.slice(1, 2); - let b_slice = b.slice(1, 2); - assert!(a_slice.equals(&*b_slice)); - assert!(b_slice.equals(&*a_slice)); - } - - #[test] - fn test_boolean_equal() { - let a = BooleanArray::from(vec![false, false, true]); - let b = BooleanArray::from(vec![false, false, true]); - assert!(a.equals(&b)); - assert!(b.equals(&a)); - - let b = BooleanArray::from(vec![false, false, false]); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - // Test the case where null_count > 0 - - let a = BooleanArray::from(vec![Some(false), None, None, Some(true)]); - let b = BooleanArray::from(vec![Some(false), None, None, Some(true)]); - assert!(a.equals(&b)); - assert!(b.equals(&a)); - - let b = BooleanArray::from(vec![None, None, None, Some(true)]); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - let b = BooleanArray::from(vec![Some(true), None, None, Some(true)]); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - // Test the case where offset != 0 - - let a = BooleanArray::from(vec![false, true, false, true, false, false, true]); - let b = BooleanArray::from(vec![false, false, false, true, false, true, true]); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - let a_slice = a.slice(2, 3); - let b_slice = b.slice(2, 3); - assert!(a_slice.equals(&*b_slice)); - assert!(b_slice.equals(&*a_slice)); - - let a_slice = a.slice(3, 4); - let b_slice = b.slice(3, 4); - assert!(!a_slice.equals(&*b_slice)); - assert!(!b_slice.equals(&*a_slice)); - } - - #[test] - fn test_list_equal() { - let mut a_builder = ListBuilder::new(Int32Builder::new(10)); - let mut b_builder = ListBuilder::new(Int32Builder::new(10)); - - let a = create_list_array(&mut a_builder, &[Some(&[1, 2, 3]), Some(&[4, 5, 6])]) - .unwrap(); - let b = create_list_array(&mut b_builder, &[Some(&[1, 2, 3]), Some(&[4, 5, 6])]) - .unwrap(); - - assert!(a.equals(&b)); - assert!(b.equals(&a)); - - let b = create_list_array(&mut a_builder, &[Some(&[1, 2, 3]), Some(&[4, 5, 7])]) - .unwrap(); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - // Test the case where null_count > 0 - - let a = create_list_array( - &mut a_builder, - &[Some(&[1, 2]), None, None, Some(&[3, 4]), None, None], - ) - .unwrap(); - let b = create_list_array( - &mut a_builder, - &[Some(&[1, 2]), None, None, Some(&[3, 4]), None, None], - ) - .unwrap(); - assert!(a.equals(&b)); - assert!(b.equals(&a)); - - let b = create_list_array( - &mut a_builder, - &[ - Some(&[1, 2]), - None, - Some(&[5, 6]), - Some(&[3, 4]), - None, - None, - ], - ) - .unwrap(); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - let b = create_list_array( - &mut a_builder, - &[Some(&[1, 2]), None, None, Some(&[3, 5]), None, None], - ) - .unwrap(); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - // Test the case where offset != 0 - - let a_slice = a.slice(0, 3); - let b_slice = b.slice(0, 3); - assert!(a_slice.equals(&*b_slice)); - assert!(b_slice.equals(&*a_slice)); - - let a_slice = a.slice(0, 5); - let b_slice = b.slice(0, 5); - assert!(!a_slice.equals(&*b_slice)); - assert!(!b_slice.equals(&*a_slice)); - - let a_slice = a.slice(4, 1); - let b_slice = b.slice(4, 1); - assert!(a_slice.equals(&*b_slice)); - assert!(b_slice.equals(&*a_slice)); - } - - #[test] - fn test_fixed_size_list_equal() { - let mut a_builder = FixedSizeListBuilder::new(Int32Builder::new(10), 3); - let mut b_builder = FixedSizeListBuilder::new(Int32Builder::new(10), 3); - - let a = create_fixed_size_list_array( - &mut a_builder, - &[Some(&[1, 2, 3]), Some(&[4, 5, 6])], - ) - .unwrap(); - let b = create_fixed_size_list_array( - &mut b_builder, - &[Some(&[1, 2, 3]), Some(&[4, 5, 6])], - ) - .unwrap(); - - assert!(a.equals(&b)); - assert!(b.equals(&a)); - - let b = create_fixed_size_list_array( - &mut a_builder, - &[Some(&[1, 2, 3]), Some(&[4, 5, 7])], - ) - .unwrap(); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - // Test the case where null_count > 0 - - let a = create_fixed_size_list_array( - &mut a_builder, - &[Some(&[1, 2, 3]), None, None, Some(&[4, 5, 6]), None, None], - ) - .unwrap(); - let b = create_fixed_size_list_array( - &mut a_builder, - &[Some(&[1, 2, 3]), None, None, Some(&[4, 5, 6]), None, None], - ) - .unwrap(); - assert!(a.equals(&b)); - assert!(b.equals(&a)); - - let b = create_fixed_size_list_array( - &mut a_builder, - &[ - Some(&[1, 2, 3]), - None, - Some(&[7, 8, 9]), - Some(&[4, 5, 6]), - None, - None, - ], - ) - .unwrap(); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - let b = create_fixed_size_list_array( - &mut a_builder, - &[Some(&[1, 2, 3]), None, None, Some(&[3, 6, 9]), None, None], - ) - .unwrap(); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - // Test the case where offset != 0 - - let a_slice = a.slice(0, 3); - let b_slice = b.slice(0, 3); - assert!(a_slice.equals(&*b_slice)); - assert!(b_slice.equals(&*a_slice)); - - // let a_slice = a.slice(0, 5); - // let b_slice = b.slice(0, 5); - // assert!(!a_slice.equals(&*b_slice)); - // assert!(!b_slice.equals(&*a_slice)); - - // let a_slice = a.slice(4, 1); - // let b_slice = b.slice(4, 1); - // assert!(a_slice.equals(&*b_slice)); - // assert!(b_slice.equals(&*a_slice)); - } - - fn test_generic_string_equal() { - let a = GenericStringArray::::from_vec(vec!["hello", "world"]); - let b = GenericStringArray::::from_vec(vec!["hello", "world"]); - assert!(a.equals(&b)); - assert!(b.equals(&a)); - - let b = GenericStringArray::::from_vec(vec!["hello", "arrow"]); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - // Test the case where null_count > 0 - - let a = GenericStringArray::::from_opt_vec(vec![ - Some("hello"), - None, - None, - Some("world"), - None, - None, - ]); - - let b = GenericStringArray::::from_opt_vec(vec![ - Some("hello"), - None, - None, - Some("world"), - None, - None, - ]); - assert!(a.equals(&b)); - assert!(b.equals(&a)); - - let b = GenericStringArray::::from_opt_vec(vec![ - Some("hello"), - Some("foo"), - None, - Some("world"), - None, - None, - ]); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - let b = GenericStringArray::::from_opt_vec(vec![ - Some("hello"), - None, - None, - Some("arrow"), - None, - None, - ]); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - // Test the case where offset != 0 - - let a_slice = a.slice(0, 3); - let b_slice = b.slice(0, 3); - assert!(a_slice.equals(&*b_slice)); - assert!(b_slice.equals(&*a_slice)); - - let a_slice = a.slice(0, 5); - let b_slice = b.slice(0, 5); - assert!(!a_slice.equals(&*b_slice)); - assert!(!b_slice.equals(&*a_slice)); - - let a_slice = a.slice(4, 1); - let b_slice = b.slice(4, 1); - assert!(a_slice.equals(&*b_slice)); - assert!(b_slice.equals(&*a_slice)); - } - - #[test] - fn test_string_equal() { - test_generic_string_equal::() - } - - #[test] - fn test_large_string_equal() { - test_generic_string_equal::() - } - - #[test] - fn test_struct_equal() { - let strings: ArrayRef = Arc::new(StringArray::from(vec![ - Some("joe"), - None, - None, - Some("mark"), - Some("doe"), - ])); - let ints: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - Some(2), - None, - Some(4), - Some(5), - ])); - - let a = - StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]) - .unwrap(); - - let b = StructArray::try_from(vec![("f1", strings), ("f2", ints)]).unwrap(); - - assert!(a.equals(&b)); - assert!(b.equals(&a)); - } - - #[test] - fn test_null_equal() { - let a = NullArray::new(12); - let b = NullArray::new(12); - assert!(a.equals(&b)); - assert!(b.equals(&a)); - - let b = NullArray::new(10); - assert!(!a.equals(&b)); - assert!(!b.equals(&a)); - - // Test the case where offset != 0 - - let a_slice = a.slice(2, 3); - let b_slice = b.slice(1, 3); - assert!(a_slice.equals(&*b_slice)); - assert!(b_slice.equals(&*a_slice)); - - let a_slice = a.slice(5, 4); - let b_slice = b.slice(3, 3); - assert!(!a_slice.equals(&*b_slice)); - assert!(!b_slice.equals(&*a_slice)); - } - - fn create_list_array<'a, U: AsRef<[i32]>, T: AsRef<[Option]>>( - builder: &'a mut ListBuilder, - data: T, - ) -> Result { - for d in data.as_ref() { - if let Some(v) = d { - builder.values().append_slice(v.as_ref())?; - builder.append(true)? - } else { - builder.append(false)? - } - } - Ok(builder.finish()) - } - - /// Create a fixed size list of 2 value lengths - fn create_fixed_size_list_array<'a, U: AsRef<[i32]>, T: AsRef<[Option]>>( - builder: &'a mut FixedSizeListBuilder, - data: T, - ) -> Result { - for d in data.as_ref() { - if let Some(v) = d { - builder.values().append_slice(v.as_ref())?; - builder.append(true)? - } else { - for _ in 0..builder.value_length() { - builder.values().append_null()?; - } - builder.append(false)? - } - } - Ok(builder.finish()) - } -} diff --git a/rust/arrow/src/array/equal/boolean.rs b/rust/arrow/src/array/equal/boolean.rs new file mode 100644 index 00000000000..4158080b81d --- /dev/null +++ b/rust/arrow/src/array/equal/boolean.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::ArrayData; + +use super::utils::equal_bits; + +pub(super) fn boolean_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_values = lhs.buffers()[0].data(); + let rhs_values = rhs.buffers()[0].data(); + + // TODO: we can do this more efficiently if all values are not-null + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + let lhs_is_null = lhs.is_null(lhs_pos); + let rhs_is_null = rhs.is_null(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && equal_bits( + lhs_values, + rhs_values, + lhs_pos + lhs.offset(), + rhs_pos + rhs.offset(), + 1, + ) + }) +} diff --git a/rust/arrow/src/array/equal/dictionary.rs b/rust/arrow/src/array/equal/dictionary.rs new file mode 100644 index 00000000000..a41b0a9b74e --- /dev/null +++ b/rust/arrow/src/array/equal/dictionary.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{array::ArrayData, datatypes::ArrowNativeType}; + +use super::equal_range; + +pub(super) fn dictionary_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_keys = lhs.buffer::(0); + let rhs_keys = rhs.buffer::(0); + + let lhs_values = lhs.child_data()[0].as_ref(); + let rhs_values = rhs.child_data()[0].as_ref(); + + if lhs.null_count() == 0 && rhs.null_count() == 0 { + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + + equal_range( + lhs_values, + rhs_values, + lhs_keys[lhs_pos].to_usize().unwrap(), + rhs_keys[rhs_pos].to_usize().unwrap(), + 1, + ) + }) + } else { + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + + let lhs_is_null = lhs.is_null(lhs_pos); + let rhs_is_null = rhs.is_null(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && equal_range( + lhs_values, + rhs_values, + lhs_keys[lhs_pos].to_usize().unwrap(), + rhs_keys[rhs_pos].to_usize().unwrap(), + 1, + ) + }) + } +} diff --git a/rust/arrow/src/array/equal/fixed_binary.rs b/rust/arrow/src/array/equal/fixed_binary.rs new file mode 100644 index 00000000000..e0fdf07ec85 --- /dev/null +++ b/rust/arrow/src/array/equal/fixed_binary.rs @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{array::ArrayData, datatypes::DataType}; + +use super::utils::equal_len; + +pub(super) fn fixed_binary_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let size = match lhs.data_type() { + DataType::FixedSizeBinary(i) => *i as usize, + _ => unreachable!(), + }; + + let lhs_values = lhs.buffer::(0); + let rhs_values = rhs.buffer::(0); + + if lhs.null_count() == 0 && rhs.null_count() == 0 { + equal_len( + lhs_values, + rhs_values, + size * lhs_start, + size * rhs_start, + size * len, + ) + } else { + // with nulls, we need to compare item by item whenever it is not null + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + + let lhs_is_null = lhs.is_null(lhs_pos); + let rhs_is_null = rhs.is_null(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && equal_len( + lhs_values, + rhs_values, + lhs_pos * size, + rhs_pos * size, + size, // 1 * size since we are comparing a single entry + ) + }) + } +} diff --git a/rust/arrow/src/array/equal/fixed_list.rs b/rust/arrow/src/array/equal/fixed_list.rs new file mode 100644 index 00000000000..aeb0d1372c8 --- /dev/null +++ b/rust/arrow/src/array/equal/fixed_list.rs @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{array::ArrayData, datatypes::DataType}; + +use super::equal_range; + +pub(super) fn fixed_list_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let size = match lhs.data_type() { + DataType::FixedSizeList(_, i) => *i as usize, + _ => unreachable!(), + }; + + let lhs_values = lhs.child_data()[0].as_ref(); + let rhs_values = rhs.child_data()[0].as_ref(); + + if lhs.null_count() == 0 && rhs.null_count() == 0 { + equal_range( + lhs_values, + rhs_values, + size * lhs_start, + size * rhs_start, + size * len, + ) + } else { + // with nulls, we need to compare item by item whenever it is not null + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + + let lhs_is_null = lhs.is_null(lhs_pos); + let rhs_is_null = rhs.is_null(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && equal_range( + lhs_values, + rhs_values, + lhs_pos * size, + rhs_pos * size, + size, // 1 * size since we are comparing a single entry + ) + }) + } +} diff --git a/rust/arrow/src/array/equal/list.rs b/rust/arrow/src/array/equal/list.rs new file mode 100644 index 00000000000..7e81a342443 --- /dev/null +++ b/rust/arrow/src/array/equal/list.rs @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{array::ArrayData, array::OffsetSizeTrait}; + +use super::equal_range; + +fn lengths_equal(lhs: &[T], rhs: &[T]) -> bool { + // invariant from `base_equal` + debug_assert_eq!(lhs.len(), rhs.len()); + + if lhs.len() == 0 { + return true; + } + + if lhs[0] == T::zero() && rhs[0] == T::zero() { + return lhs == rhs; + }; + + // The expensive case, e.g. + // [0, 2, 4, 6, 9] == [4, 6, 8, 10, 13] + lhs.windows(2) + .zip(rhs.windows(2)) + .all(|(lhs_offsets, rhs_offsets)| { + // length of left == length of right + (lhs_offsets[1] - lhs_offsets[0]) == (rhs_offsets[1] - rhs_offsets[0]) + }) +} + +#[inline] +fn offset_value_equal( + lhs_values: &ArrayData, + rhs_values: &ArrayData, + lhs_offsets: &[T], + rhs_offsets: &[T], + lhs_pos: usize, + rhs_pos: usize, + len: usize, +) -> bool { + let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap(); + let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap(); + let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos]; + let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos]; + + lhs_len == rhs_len + && equal_range( + lhs_values, + rhs_values, + lhs_start, + rhs_start, + lhs_len.to_usize().unwrap(), + ) +} + +pub(super) fn list_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_offsets = lhs.buffer::(0); + let rhs_offsets = rhs.buffer::(0); + + let lhs_values = lhs.child_data()[0].as_ref(); + let rhs_values = rhs.child_data()[0].as_ref(); + + if lhs.null_count() == 0 && rhs.null_count() == 0 { + lengths_equal( + &lhs_offsets[lhs_start..lhs_start + len], + &rhs_offsets[rhs_start..rhs_start + len], + ) && equal_range( + lhs_values, + rhs_values, + lhs_offsets[lhs_start].to_usize().unwrap(), + rhs_offsets[rhs_start].to_usize().unwrap(), + (lhs_offsets[len] - lhs_offsets[lhs_start]) + .to_usize() + .unwrap(), + ) + } else { + // with nulls, we need to compare item by item whenever it is not null + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + + let lhs_is_null = lhs.is_null(lhs_pos); + let rhs_is_null = rhs.is_null(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && offset_value_equal::( + lhs_values, + rhs_values, + lhs_offsets, + rhs_offsets, + lhs_pos, + rhs_pos, + 1, + ) + }) + } +} diff --git a/rust/arrow/src/array/equal/mod.rs b/rust/arrow/src/array/equal/mod.rs new file mode 100644 index 00000000000..400a6e24cf7 --- /dev/null +++ b/rust/arrow/src/array/equal/mod.rs @@ -0,0 +1,831 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing functionality to compute array equality. +//! This module uses [ArrayData] and does not +//! depend on dynamic casting of `Array`. + +use super::{ + array::BinaryOffsetSizeTrait, Array, ArrayData, FixedSizeBinaryArray, + GenericBinaryArray, GenericListArray, GenericStringArray, OffsetSizeTrait, + PrimitiveArray, StringOffsetSizeTrait, StructArray, +}; + +use crate::datatypes::{ArrowPrimitiveType, DataType, IntervalUnit}; + +mod boolean; +mod dictionary; +mod fixed_binary; +mod fixed_list; +mod list; +mod null; +mod primitive; +mod structure; +mod utils; +mod variable_size; + +// these methods assume the same type, len and null count. +// For this reason, they are not exposed and are instead used +// to build the generic functions below (`equal_range` and `equal`). +use boolean::boolean_equal; +use dictionary::dictionary_equal; +use fixed_binary::fixed_binary_equal; +use fixed_list::fixed_list_equal; +use list::list_equal; +use null::null_equal; +use primitive::primitive_equal; +use structure::struct_equal; +use variable_size::variable_sized_equal; + +impl PartialEq for dyn Array { + fn eq(&self, other: &Self) -> bool { + equal(self.data().as_ref(), other.data().as_ref()) + } +} + +impl PartialEq for dyn Array { + fn eq(&self, other: &T) -> bool { + equal(self.data().as_ref(), other.data().as_ref()) + } +} + +impl PartialEq for PrimitiveArray { + fn eq(&self, other: &PrimitiveArray) -> bool { + equal(self.data().as_ref(), other.data().as_ref()) + } +} + +impl PartialEq for GenericStringArray { + fn eq(&self, other: &Self) -> bool { + equal(self.data().as_ref(), other.data().as_ref()) + } +} + +impl PartialEq for GenericBinaryArray { + fn eq(&self, other: &Self) -> bool { + equal(self.data().as_ref(), other.data().as_ref()) + } +} + +impl PartialEq for FixedSizeBinaryArray { + fn eq(&self, other: &Self) -> bool { + equal(self.data().as_ref(), other.data().as_ref()) + } +} + +impl PartialEq for GenericListArray { + fn eq(&self, other: &Self) -> bool { + equal(self.data().as_ref(), other.data().as_ref()) + } +} + +impl PartialEq for StructArray { + fn eq(&self, other: &Self) -> bool { + equal(self.data().as_ref(), other.data().as_ref()) + } +} + +/// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively +/// for `len` slots. +#[inline] +fn equal_values( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + match lhs.data_type() { + DataType::Null => null_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Boolean => boolean_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt8 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int8 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Float32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Float64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Date32(_) + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Date64(_) + | DataType::Interval(IntervalUnit::DayTime) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => { + primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Utf8 | DataType::Binary => { + variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::LargeUtf8 | DataType::LargeBinary => { + variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::FixedSizeBinary(_) => { + fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::List(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::LargeList(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::FixedSizeList(_, _) => { + fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Union(_) => unimplemented!("See ARROW-8576"), + DataType::Dictionary(data_type, _) => match data_type.as_ref() { + DataType::Int8 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int16 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Int32 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Int64 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::UInt8 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::UInt16 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::UInt32 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::UInt64 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + _ => unreachable!(), + }, + DataType::Float16 => unreachable!(), + } +} + +fn equal_range( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + utils::base_equal(lhs, rhs) + && utils::equal_nulls(lhs, rhs, lhs_start, rhs_start, len) + && equal_values(lhs, rhs, lhs_start, rhs_start, len) +} + +/// Logically compares two [ArrayData]. +/// Two arrays are logically equal if and only if: +/// * their data types are equal +/// * their lenghts are equal +/// * their null counts are equal +/// * their null bitmaps are equal +/// * each of their items are equal +/// two items are equal when their in-memory representation is physically equal (i.e. same bit content). +/// The physical comparison depend on the data type. +/// # Panics +/// This function may panic whenever any of the [ArrayData] does not follow the Arrow specification. +/// (e.g. wrong number of buffers, buffer `len` does not correspond to the declared `len`) +pub fn equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { + utils::base_equal(lhs, rhs) + && lhs.null_count() == rhs.null_count() + && utils::equal_nulls(lhs, rhs, 0, 0, lhs.len()) + && equal_values(lhs, rhs, 0, 0, lhs.len()) +} + +#[cfg(test)] +mod tests { + use std::convert::TryFrom; + use std::sync::Arc; + + use crate::array::{ + array::Array, array::BinaryOffsetSizeTrait, ArrayDataRef, ArrayRef, BooleanArray, + FixedSizeBinaryBuilder, FixedSizeListBuilder, GenericBinaryArray, Int32Builder, + ListBuilder, NullArray, PrimitiveBuilder, StringArray, StringDictionaryBuilder, + StringOffsetSizeTrait, StructArray, + }; + use crate::array::{GenericStringArray, Int32Array}; + use crate::datatypes::Int16Type; + + use super::*; + + #[test] + fn test_null_equal() { + let a = NullArray::new(12).data(); + let b = NullArray::new(12).data(); + test_equal(&a, &b, true); + + let b = NullArray::new(10).data(); + test_equal(&a, &b, false); + + // Test the case where offset != 0 + + let a_slice = a.slice(2, 3); + let b_slice = b.slice(1, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.slice(5, 4); + let b_slice = b.slice(3, 3); + test_equal(&a_slice, &b_slice, false); + } + + #[test] + fn test_boolean_equal() { + let a = BooleanArray::from(vec![false, false, true]).data(); + let b = BooleanArray::from(vec![false, false, true]).data(); + test_equal(a.as_ref(), b.as_ref(), true); + + let b = BooleanArray::from(vec![false, false, false]).data(); + test_equal(a.as_ref(), b.as_ref(), false); + + // Test the case where null_count > 0 + + let a = BooleanArray::from(vec![Some(false), None, None, Some(true)]).data(); + let b = BooleanArray::from(vec![Some(false), None, None, Some(true)]).data(); + test_equal(a.as_ref(), b.as_ref(), true); + + let b = BooleanArray::from(vec![None, None, None, Some(true)]).data(); + test_equal(a.as_ref(), b.as_ref(), false); + + let b = BooleanArray::from(vec![Some(true), None, None, Some(true)]).data(); + test_equal(a.as_ref(), b.as_ref(), false); + + // Test the case where offset != 0 + + let a = + BooleanArray::from(vec![false, true, false, true, false, false, true]).data(); + let b = + BooleanArray::from(vec![false, false, false, true, false, true, true]).data(); + assert_eq!(equal(a.as_ref(), b.as_ref()), false); + assert_eq!(equal(b.as_ref(), a.as_ref()), false); + + let a_slice = a.slice(2, 3); + let b_slice = b.slice(2, 3); + assert_eq!(equal(&a_slice, &b_slice), true); + assert_eq!(equal(&b_slice, &a_slice), true); + + let a_slice = a.slice(3, 4); + let b_slice = b.slice(3, 4); + assert_eq!(equal(&a_slice, &b_slice), false); + assert_eq!(equal(&b_slice, &a_slice), false); + } + + #[test] + fn test_primitive() { + let cases = vec![ + ( + vec![Some(1), Some(2), Some(3)], + vec![Some(1), Some(2), Some(3)], + true, + ), + ( + vec![Some(1), Some(2), Some(3)], + vec![Some(1), Some(2), Some(4)], + false, + ), + ( + vec![Some(1), Some(2), None], + vec![Some(1), Some(2), None], + true, + ), + ( + vec![Some(1), None, Some(3)], + vec![Some(1), Some(2), None], + false, + ), + ( + vec![Some(1), None, None], + vec![Some(1), Some(2), None], + false, + ), + ]; + + for (lhs, rhs, expected) in cases { + let lhs = Int32Array::from(lhs).data(); + let rhs = Int32Array::from(rhs).data(); + test_equal(&lhs, &rhs, expected); + } + } + + #[test] + fn test_primitive_slice() { + let cases = vec![ + ( + vec![Some(1), Some(2), Some(3)], + (0, 1), + vec![Some(1), Some(2), Some(3)], + (0, 1), + true, + ), + ( + vec![Some(1), Some(2), Some(3)], + (1, 1), + vec![Some(1), Some(2), Some(3)], + (2, 1), + false, + ), + ( + vec![Some(1), Some(2), None], + (1, 1), + vec![Some(1), None, Some(2)], + (2, 1), + true, + ), + ]; + + for (lhs, slice_lhs, rhs, slice_rhs, expected) in cases { + let lhs = Int32Array::from(lhs).data(); + let lhs = lhs.slice(slice_lhs.0, slice_lhs.1); + let rhs = Int32Array::from(rhs).data(); + let rhs = rhs.slice(slice_rhs.0, slice_rhs.1); + + test_equal(&lhs, &rhs, expected); + } + } + + fn test_equal(lhs: &ArrayData, rhs: &ArrayData, expected: bool) { + // equality is symetric + assert_eq!(equal(lhs, lhs), true, "\n{:?}\n{:?}", lhs, lhs); + assert_eq!(equal(rhs, rhs), true, "\n{:?}\n{:?}", rhs, rhs); + + assert_eq!(equal(lhs, rhs), expected, "\n{:?}\n{:?}", lhs, rhs); + assert_eq!(equal(rhs, lhs), expected, "\n{:?}\n{:?}", rhs, lhs); + } + + fn binary_cases() -> Vec<(Vec>, Vec>, bool)> { + let base = vec![ + Some("hello".to_owned()), + None, + None, + Some("world".to_owned()), + None, + None, + ]; + let not_base = vec![ + Some("hello".to_owned()), + Some("foo".to_owned()), + None, + Some("world".to_owned()), + None, + None, + ]; + vec![ + ( + vec![Some("hello".to_owned()), Some("world".to_owned())], + vec![Some("hello".to_owned()), Some("world".to_owned())], + true, + ), + ( + vec![Some("hello".to_owned()), Some("world".to_owned())], + vec![Some("hello".to_owned()), Some("arrow".to_owned())], + false, + ), + (base.clone(), base.clone(), true), + (base.clone(), not_base.clone(), false), + ] + } + + fn test_generic_string_equal() { + let cases = binary_cases(); + + for (lhs, rhs, expected) in cases { + let lhs = lhs.iter().map(|x| x.as_deref()).collect(); + let rhs = rhs.iter().map(|x| x.as_deref()).collect(); + let lhs = GenericStringArray::::from_opt_vec(lhs).data(); + let rhs = GenericStringArray::::from_opt_vec(rhs).data(); + test_equal(lhs.as_ref(), rhs.as_ref(), expected); + } + } + + #[test] + fn test_string_equal() { + test_generic_string_equal::() + } + + #[test] + fn test_large_string_equal() { + test_generic_string_equal::() + } + + fn test_generic_binary_equal() { + let cases = binary_cases(); + + for (lhs, rhs, expected) in cases { + let lhs = lhs + .iter() + .map(|x| x.as_deref().map(|x| x.as_bytes())) + .collect(); + let rhs = rhs + .iter() + .map(|x| x.as_deref().map(|x| x.as_bytes())) + .collect(); + let lhs = GenericBinaryArray::::from_opt_vec(lhs).data(); + let rhs = GenericBinaryArray::::from_opt_vec(rhs).data(); + test_equal(lhs.as_ref(), rhs.as_ref(), expected); + } + } + + #[test] + fn test_binary_equal() { + test_generic_binary_equal::() + } + + #[test] + fn test_large_binary_equal() { + test_generic_binary_equal::() + } + + #[test] + fn test_null() { + let a = NullArray::new(2).data(); + let b = NullArray::new(2).data(); + test_equal(a.as_ref(), b.as_ref(), true); + + let b = NullArray::new(1).data(); + test_equal(a.as_ref(), b.as_ref(), false); + } + + fn create_list_array, T: AsRef<[Option]>>( + data: T, + ) -> ArrayDataRef { + let mut builder = ListBuilder::new(Int32Builder::new(10)); + for d in data.as_ref() { + if let Some(v) = d { + builder.values().append_slice(v.as_ref()).unwrap(); + builder.append(true).unwrap() + } else { + builder.append(false).unwrap() + } + } + builder.finish().data() + } + + #[test] + fn test_list_equal() { + let a = create_list_array(&[Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + let b = create_list_array(&[Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + test_equal(a.as_ref(), b.as_ref(), true); + + let b = create_list_array(&[Some(&[1, 2, 3]), Some(&[4, 5, 7])]); + test_equal(a.as_ref(), b.as_ref(), false); + } + + // Test the case where null_count > 0 + #[test] + fn test_list_null() { + let a = + create_list_array(&[Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + let b = + create_list_array(&[Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + test_equal(a.as_ref(), b.as_ref(), true); + + let b = create_list_array(&[ + Some(&[1, 2]), + None, + Some(&[5, 6]), + Some(&[3, 4]), + None, + None, + ]); + test_equal(a.as_ref(), b.as_ref(), false); + + let b = + create_list_array(&[Some(&[1, 2]), None, None, Some(&[3, 5]), None, None]); + test_equal(a.as_ref(), b.as_ref(), false); + } + + // Test the case where offset != 0 + #[test] + fn test_list_offsets() { + let a = + create_list_array(&[Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + let b = + create_list_array(&[Some(&[1, 2]), None, None, Some(&[3, 5]), None, None]); + + let a_slice = a.slice(0, 3); + let b_slice = b.slice(0, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.slice(0, 5); + let b_slice = b.slice(0, 5); + test_equal(&a_slice, &b_slice, false); + + let a_slice = a.slice(4, 1); + let b_slice = b.slice(4, 1); + test_equal(&a_slice, &b_slice, true); + } + + fn create_fixed_size_binary_array, T: AsRef<[Option]>>( + data: T, + ) -> ArrayDataRef { + let mut builder = FixedSizeBinaryBuilder::new(15, 5); + + for d in data.as_ref() { + if let Some(v) = d { + builder.append_value(v.as_ref()).unwrap(); + } else { + builder.append_null().unwrap(); + } + } + builder.finish().data() + } + + #[test] + fn test_fixed_size_binary_equal() { + let a = create_fixed_size_binary_array(&[Some(b"hello"), Some(b"world")]); + let b = create_fixed_size_binary_array(&[Some(b"hello"), Some(b"world")]); + test_equal(a.as_ref(), b.as_ref(), true); + + let b = create_fixed_size_binary_array(&[Some(b"hello"), Some(b"arrow")]); + test_equal(a.as_ref(), b.as_ref(), false); + } + + // Test the case where null_count > 0 + #[test] + fn test_fixed_size_binary_null() { + let a = create_fixed_size_binary_array(&[Some(b"hello"), None, Some(b"world")]); + let b = create_fixed_size_binary_array(&[Some(b"hello"), None, Some(b"world")]); + test_equal(a.as_ref(), b.as_ref(), true); + + let b = create_fixed_size_binary_array(&[Some(b"hello"), Some(b"world"), None]); + test_equal(a.as_ref(), b.as_ref(), false); + + let b = create_fixed_size_binary_array(&[Some(b"hello"), None, Some(b"arrow")]); + test_equal(a.as_ref(), b.as_ref(), false); + } + + #[test] + fn test_fixed_size_binary_offsets() { + // Test the case where offset != 0 + let a = create_fixed_size_binary_array(&[ + Some(b"hello"), + None, + None, + Some(b"world"), + None, + None, + ]); + let b = create_fixed_size_binary_array(&[ + Some(b"hello"), + None, + None, + Some(b"arrow"), + None, + None, + ]); + + let a_slice = a.slice(0, 3); + let b_slice = b.slice(0, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.slice(0, 5); + let b_slice = b.slice(0, 5); + test_equal(&a_slice, &b_slice, false); + + let a_slice = a.slice(4, 1); + let b_slice = b.slice(4, 1); + test_equal(&a_slice, &b_slice, true); + } + + /// Create a fixed size list of 2 value lengths + fn create_fixed_size_list_array, T: AsRef<[Option]>>( + data: T, + ) -> ArrayDataRef { + let mut builder = FixedSizeListBuilder::new(Int32Builder::new(10), 3); + + for d in data.as_ref() { + if let Some(v) = d { + builder.values().append_slice(v.as_ref()).unwrap(); + builder.append(true).unwrap() + } else { + for _ in 0..builder.value_length() { + builder.values().append_null().unwrap(); + } + builder.append(false).unwrap() + } + } + builder.finish().data() + } + + #[test] + fn test_fixed_size_list_equal() { + let a = create_fixed_size_list_array(&[Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + let b = create_fixed_size_list_array(&[Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + test_equal(a.as_ref(), b.as_ref(), true); + + let b = create_fixed_size_list_array(&[Some(&[1, 2, 3]), Some(&[4, 5, 7])]); + test_equal(a.as_ref(), b.as_ref(), false); + } + + // Test the case where null_count > 0 + #[test] + fn test_fixed_list_null() { + let a = create_fixed_size_list_array(&[ + Some(&[1, 2, 3]), + None, + None, + Some(&[4, 5, 6]), + None, + None, + ]); + let b = create_fixed_size_list_array(&[ + Some(&[1, 2, 3]), + None, + None, + Some(&[4, 5, 6]), + None, + None, + ]); + test_equal(a.as_ref(), b.as_ref(), true); + + let b = create_fixed_size_list_array(&[ + Some(&[1, 2, 3]), + None, + Some(&[7, 8, 9]), + Some(&[4, 5, 6]), + None, + None, + ]); + test_equal(a.as_ref(), b.as_ref(), false); + + let b = create_fixed_size_list_array(&[ + Some(&[1, 2, 3]), + None, + None, + Some(&[3, 6, 9]), + None, + None, + ]); + test_equal(a.as_ref(), b.as_ref(), false); + } + + #[test] + fn test_fixed_list_offsets() { + // Test the case where offset != 0 + let a = create_fixed_size_list_array(&[ + Some(&[1, 2, 3]), + None, + None, + Some(&[4, 5, 6]), + None, + None, + ]); + let b = create_fixed_size_list_array(&[ + Some(&[1, 2, 3]), + None, + None, + Some(&[3, 6, 9]), + None, + None, + ]); + + let a_slice = a.slice(0, 3); + let b_slice = b.slice(0, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.slice(0, 5); + let b_slice = b.slice(0, 5); + test_equal(&a_slice, &b_slice, false); + + let a_slice = a.slice(4, 1); + let b_slice = b.slice(4, 1); + test_equal(&a_slice, &b_slice, true); + } + + #[test] + fn test_struct_equal() { + let strings: ArrayRef = Arc::new(StringArray::from(vec![ + Some("joe"), + None, + None, + Some("mark"), + Some("doe"), + ])); + let ints: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(4), + Some(5), + ])); + + let a = + StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]) + .unwrap() + .data(); + + let b = StructArray::try_from(vec![("f1", strings), ("f2", ints)]) + .unwrap() + .data(); + + test_equal(a.as_ref(), b.as_ref(), true); + } + + fn create_dictionary_array(values: &[&str], keys: &[Option<&str>]) -> ArrayDataRef { + let values = StringArray::from(values.to_vec()); + let mut builder = StringDictionaryBuilder::new_with_dictionary( + PrimitiveBuilder::::new(3), + &values, + ) + .unwrap(); + for key in keys { + if let Some(v) = key { + builder.append(v).unwrap(); + } else { + builder.append_null().unwrap() + } + } + builder.finish().data() + } + + #[test] + fn test_dictionary_equal() { + // (a, b, c), (1, 2, 1, 3) => (a, b, a, c) + let a = create_dictionary_array( + &["a", "b", "c"], + &[Some("a"), Some("b"), Some("a"), Some("c")], + ); + // different representation (values and keys are swapped), same result + let b = create_dictionary_array( + &["a", "c", "b"], + &[Some("a"), Some("b"), Some("a"), Some("c")], + ); + test_equal(a.as_ref(), b.as_ref(), true); + + // different len + let b = + create_dictionary_array(&["a", "c", "b"], &[Some("a"), Some("b"), Some("a")]); + test_equal(a.as_ref(), b.as_ref(), false); + + // different key + let b = create_dictionary_array( + &["a", "c", "b"], + &[Some("a"), Some("b"), Some("a"), Some("a")], + ); + test_equal(a.as_ref(), b.as_ref(), false); + + // different values, same keys + let b = create_dictionary_array( + &["a", "b", "d"], + &[Some("a"), Some("b"), Some("a"), Some("d")], + ); + test_equal(a.as_ref(), b.as_ref(), false); + } + + #[test] + fn test_dictionary_equal_null() { + // (a, b, c), (1, 2, 1, 3) => (a, b, a, c) + let a = create_dictionary_array( + &["a", "b", "c"], + &[Some("a"), None, Some("a"), Some("c")], + ); + + // equal to self + test_equal(a.as_ref(), a.as_ref(), true); + + // different representation (values and keys are swapped), same result + let b = create_dictionary_array( + &["a", "c", "b"], + &[Some("a"), None, Some("a"), Some("c")], + ); + test_equal(a.as_ref(), b.as_ref(), true); + + // different null position + let b = create_dictionary_array( + &["a", "c", "b"], + &[Some("a"), Some("b"), Some("a"), None], + ); + test_equal(a.as_ref(), b.as_ref(), false); + + // different key + let b = create_dictionary_array( + &["a", "c", "b"], + &[Some("a"), None, Some("a"), Some("a")], + ); + test_equal(a.as_ref(), b.as_ref(), false); + + // different values, same keys + let b = create_dictionary_array( + &["a", "b", "d"], + &[Some("a"), None, Some("a"), Some("d")], + ); + test_equal(a.as_ref(), b.as_ref(), false); + } +} diff --git a/rust/arrow/src/array/equal/null.rs b/rust/arrow/src/array/equal/null.rs new file mode 100644 index 00000000000..f287a382507 --- /dev/null +++ b/rust/arrow/src/array/equal/null.rs @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::ArrayData; + +#[inline] +pub(super) fn null_equal( + _lhs: &ArrayData, + _rhs: &ArrayData, + _lhs_start: usize, + _rhs_start: usize, + _len: usize, +) -> bool { + // a null buffer's range is always true, as every element is by definition equal (to null). + // We only need to compare data_types + true +} diff --git a/rust/arrow/src/array/equal/primitive.rs b/rust/arrow/src/array/equal/primitive.rs new file mode 100644 index 00000000000..19602e46488 --- /dev/null +++ b/rust/arrow/src/array/equal/primitive.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::mem::size_of; + +use crate::array::ArrayData; + +use super::utils::equal_len; + +pub(super) fn primitive_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let byte_width = size_of::(); + let lhs_values = &lhs.buffers()[0].data()[lhs.offset() * byte_width..]; + let rhs_values = &rhs.buffers()[0].data()[rhs.offset() * byte_width..]; + + if lhs.null_count() == 0 && rhs.null_count() == 0 { + // without nulls, we just need to compare slices + equal_len( + lhs_values, + rhs_values, + lhs_start * byte_width, + rhs_start * byte_width, + len * byte_width, + ) + } else { + // with nulls, we need to compare item by item whenever it is not null + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + let lhs_is_null = lhs.is_null(lhs_pos); + let rhs_is_null = rhs.is_null(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && equal_len( + lhs_values, + rhs_values, + lhs_pos * byte_width, + rhs_pos * byte_width, + byte_width, // 1 * byte_width since we are comparing a single entry + ) + }) + } +} diff --git a/rust/arrow/src/array/equal/structure.rs b/rust/arrow/src/array/equal/structure.rs new file mode 100644 index 00000000000..1e8a1ff260b --- /dev/null +++ b/rust/arrow/src/array/equal/structure.rs @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::ArrayData; + +use super::equal_range; + +fn equal_values( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + lhs.child_data() + .iter() + .zip(rhs.child_data()) + .all(|(lhs_values, rhs_values)| { + equal_range(lhs_values, rhs_values, lhs_start, rhs_start, len) + }) +} + +pub(super) fn struct_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + if lhs.null_count() == 0 && rhs.null_count() == 0 { + equal_values(lhs, rhs, lhs_start, rhs_start, len) + } else { + // with nulls, we need to compare item by item whenever it is not null + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + let lhs_is_null = lhs.is_null(lhs_pos); + let rhs_is_null = rhs.is_null(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && equal_values(lhs, rhs, lhs_pos, rhs_pos, 1) + }) + } +} diff --git a/rust/arrow/src/array/equal/utils.rs b/rust/arrow/src/array/equal/utils.rs new file mode 100644 index 00000000000..f9e8860a5bb --- /dev/null +++ b/rust/arrow/src/array/equal/utils.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{array::ArrayData, util::bit_util}; + +// whether bits along the positions are equal +// `lhs_start`, `rhs_start` and `len` are _measured in bits_. +#[inline] +pub(super) fn equal_bits( + lhs_values: &[u8], + rhs_values: &[u8], + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + (0..len).all(|i| { + bit_util::get_bit(lhs_values, lhs_start + i) + == bit_util::get_bit(rhs_values, rhs_start + i) + }) +} + +#[inline] +pub(super) fn equal_nulls( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + if lhs.null_count() > 0 || rhs.null_count() > 0 { + let lhs_null_bitmap = lhs.null_bitmap().as_ref().unwrap(); + let rhs_null_bitmap = rhs.null_bitmap().as_ref().unwrap(); + let lhs_values = lhs_null_bitmap.bits.data(); + let rhs_values = rhs_null_bitmap.bits.data(); + equal_bits( + lhs_values, + rhs_values, + lhs_start + lhs.offset(), + rhs_start + rhs.offset(), + len, + ) + } else { + true + } +} + +#[inline] +pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() +} + +// whether the two memory regions are equal +#[inline] +pub(super) fn equal_len( + lhs_values: &[u8], + rhs_values: &[u8], + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + lhs_values[lhs_start..(lhs_start + len)] == rhs_values[rhs_start..(rhs_start + len)] +} diff --git a/rust/arrow/src/array/equal/variable_size.rs b/rust/arrow/src/array/equal/variable_size.rs new file mode 100644 index 00000000000..237b353d287 --- /dev/null +++ b/rust/arrow/src/array/equal/variable_size.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::{ArrayData, OffsetSizeTrait}; + +use super::utils::equal_len; + +fn offset_value_equal( + lhs_values: &[u8], + rhs_values: &[u8], + lhs_offsets: &[T], + rhs_offsets: &[T], + lhs_pos: usize, + rhs_pos: usize, + len: usize, +) -> bool { + let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap(); + let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap(); + let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos]; + let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos]; + + lhs_len == rhs_len + && equal_len( + lhs_values, + rhs_values, + lhs_start, + rhs_start, + lhs_len.to_usize().unwrap(), + ) +} + +pub(super) fn variable_sized_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_offsets = lhs.buffer::(0); + let rhs_offsets = rhs.buffer::(0); + + // these are bytes, and thus the offset does not need to be multiplied + let lhs_values = &lhs.buffers()[1].data()[lhs.offset()..]; + let rhs_values = &rhs.buffers()[1].data()[rhs.offset()..]; + + if lhs.null_count() == 0 && rhs.null_count() == 0 { + offset_value_equal( + lhs_values, + rhs_values, + lhs_offsets, + rhs_offsets, + lhs_start, + rhs_start, + len, + ) + } else { + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + + let lhs_is_null = lhs.is_null(lhs_pos); + let rhs_is_null = rhs.is_null(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && offset_value_equal( + lhs_values, + rhs_values, + lhs_offsets, + rhs_offsets, + lhs_pos, + rhs_pos, + 1, + ) + }) + } +} diff --git a/rust/arrow/src/array/mod.rs b/rust/arrow/src/array/mod.rs index 88d5d74cac4..0c96b7948ed 100644 --- a/rust/arrow/src/array/mod.rs +++ b/rust/arrow/src/array/mod.rs @@ -247,7 +247,6 @@ pub use self::iterator::*; // --------------------- Array Equality --------------------- -pub use self::equal::ArrayEqual; pub use self::equal_json::JsonEqual; // --------------------- Array's values comparison --------------------- diff --git a/rust/arrow/src/compute/kernels/concat.rs b/rust/arrow/src/compute/kernels/concat.rs index d07d35e82b1..8c0965adbf8 100644 --- a/rust/arrow/src/compute/kernels/concat.rs +++ b/rust/arrow/src/compute/kernels/concat.rs @@ -220,12 +220,7 @@ mod tests { Some("baz"), ])) as ArrayRef; - assert!( - arr.equals(&(*expected_output)), - "expect {:#?} to be: {:#?}", - arr, - &expected_output - ); + assert_eq!(&arr, &expected_output); Ok(()) } @@ -268,12 +263,7 @@ mod tests { Some(1024), ])) as ArrayRef; - assert!( - arr.equals(&(*expected_output)), - "expect {:#?} to be: {:#?}", - arr, - &expected_output - ); + assert_eq!(&arr, &expected_output); Ok(()) } @@ -310,12 +300,7 @@ mod tests { Some(false), ])) as ArrayRef; - assert!( - arr.equals(&(*expected_output)), - "expect {:#?} to be: {:#?}", - arr, - &expected_output - ); + assert_eq!(&arr, &expected_output); Ok(()) } @@ -379,14 +364,9 @@ mod tests { Arc::new(builder_in3.finish()), ])?; - let array_expected = builder_expected.finish(); + let array_expected = Arc::new(builder_expected.finish()) as ArrayRef; - assert!( - array_result.equals(&array_expected), - "expect {:#?} to be: {:#?}", - array_result, - &array_expected - ); + assert_eq!(&array_result, &array_expected); Ok(()) } diff --git a/rust/arrow/src/compute/kernels/sort.rs b/rust/arrow/src/compute/kernels/sort.rs index dcc8e391caf..da570ef3241 100644 --- a/rust/arrow/src/compute/kernels/sort.rs +++ b/rust/arrow/src/compute/kernels/sort.rs @@ -565,12 +565,12 @@ mod tests { expected_data: Vec, ) where T: ArrowPrimitiveType, - PrimitiveArray: From>> + ArrayEqual, + PrimitiveArray: From>>, { let output = PrimitiveArray::::from(data); let expected = UInt32Array::from(expected_data); let output = sort_to_indices(&(Arc::new(output) as ArrayRef), options).unwrap(); - assert!(output.equals(&expected)) + assert_eq!(output, expected) } fn test_sort_primitive_arrays( @@ -579,13 +579,12 @@ mod tests { expected_data: Vec>, ) where T: ArrowPrimitiveType, - PrimitiveArray: From>> + ArrayEqual, + PrimitiveArray: From>>, { let output = PrimitiveArray::::from(data); - let expected = PrimitiveArray::::from(expected_data); + let expected = Arc::new(PrimitiveArray::::from(expected_data)) as ArrayRef; let output = sort(&(Arc::new(output) as ArrayRef), options).unwrap(); - let output = output.as_any().downcast_ref::>().unwrap(); - assert!(output.equals(&expected)) + assert_eq!(&output, &expected) } fn test_sort_to_indices_string_arrays( @@ -596,7 +595,7 @@ mod tests { let output = StringArray::from(data); let expected = UInt32Array::from(expected_data); let output = sort_to_indices(&(Arc::new(output) as ArrayRef), options).unwrap(); - assert!(output.equals(&expected)) + assert_eq!(output, expected) } fn test_sort_string_arrays( @@ -605,10 +604,9 @@ mod tests { expected_data: Vec>, ) { let output = StringArray::from(data); - let expected = StringArray::from(expected_data); + let expected = Arc::new(StringArray::from(expected_data)) as ArrayRef; let output = sort(&(Arc::new(output) as ArrayRef), options).unwrap(); - let output = output.as_any().downcast_ref::().unwrap(); - assert!(output.equals(&expected)) + assert_eq!(&output, &expected) } fn test_sort_string_dict_arrays( @@ -635,7 +633,7 @@ mod tests { .expect("Unable to get dictionary values"); let sorted_keys = sorted.keys_array(); - assert!(sorted_dict.equals(dict)); + assert_eq!(sorted_dict, dict); let sorted_strings = StringArray::try_from( (0..sorted.len()) @@ -652,32 +650,14 @@ mod tests { let expected = StringArray::try_from(expected_data).expect("Unable to create string array"); - assert!(sorted_strings.equals(&expected)) + assert_eq!(sorted_strings, expected) } fn test_lex_sort_arrays(input: Vec, expected_output: Vec) { let sorted = lexsort(&input).unwrap(); - let sorted2cmp = sorted.iter().map(|arr| -> Box<&dyn ArrayEqual> { - match arr.data_type() { - DataType::Int64 => Box::new(as_primitive_array::(&arr)), - DataType::UInt32 => Box::new(as_primitive_array::(&arr)), - DataType::Utf8 => Box::new(as_string_array(&arr)), - DataType::Dictionary(key_type, _) => match key_type.as_ref() { - DataType::Int8 => Box::new(as_dictionary_array::(&arr)), - DataType::Int16 => Box::new(as_dictionary_array::(&arr)), - DataType::Int32 => Box::new(as_dictionary_array::(&arr)), - _ => panic!("unexpected dictionary key type"), - }, - _ => panic!("unexpected array type"), - } - }); - for (i, values) in sorted2cmp.enumerate() { - assert!( - values.equals(&(*expected_output[i])), - "expect {:#?} to be: {:#?}", - sorted, - expected_output - ); + + for (result, expected) in sorted.iter().zip(expected_output.iter()) { + assert_eq!(result, expected); } } diff --git a/rust/arrow/src/compute/kernels/take.rs b/rust/arrow/src/compute/kernels/take.rs index 08016b96c80..9c9ca56fec3 100644 --- a/rust/arrow/src/compute/kernels/take.rs +++ b/rust/arrow/src/compute/kernels/take.rs @@ -460,16 +460,12 @@ mod tests { expected_data: Vec>, ) where T: ArrowPrimitiveType, - PrimitiveArray: From>> + ArrayEqual, + PrimitiveArray: From>>, { let output = PrimitiveArray::::from(data); - let expected = PrimitiveArray::::from(expected_data); + let expected = Arc::new(PrimitiveArray::::from(expected_data)) as ArrayRef; let output = take(&(Arc::new(output) as ArrayRef), index, options).unwrap(); - let output = output.as_any().downcast_ref::>().unwrap(); - assert!( - output.equals(&expected), - format!("{:?} =! {:?}", output.data(), expected.data()) - ) + assert_eq!(&output, &expected) } fn test_take_impl_primitive_arrays( @@ -479,7 +475,7 @@ mod tests { expected_data: Vec>, ) where T: ArrowPrimitiveType, - PrimitiveArray: From>> + ArrayEqual, + PrimitiveArray: From>>, I: ArrowNumericType, I::Native: ToPrimitive, { @@ -487,10 +483,7 @@ mod tests { let expected = PrimitiveArray::::from(expected_data); let output = take_impl(&(Arc::new(output) as ArrayRef), index, options).unwrap(); let output = output.as_any().downcast_ref::>().unwrap(); - assert!( - output.equals(&expected), - format!("{:?} =! {:?}", output.data(), expected.data()) - ) + assert_eq!(output, &expected) } // create a simple struct for testing purposes @@ -731,7 +724,7 @@ mod tests { fn _test_take_string<'a, K: 'static>() where - K: Array + From>>, + K: Array + PartialEq + From>>, { let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]); @@ -752,12 +745,7 @@ mod tests { let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]); - assert!( - actual.equals(&expected), - "{:?} != {:?}", - actual.data(), - expected.data() - ); + assert_eq!(actual, &expected); } #[test] @@ -828,7 +816,7 @@ mod tests { .build(); let expected_list_array = $list_array_type::from(expected_list_data); - assert!(a.equals(&expected_list_array)); + assert_eq!(a, &expected_list_array); }}; } @@ -902,7 +890,7 @@ mod tests { .build(); let expected_list_array = $list_array_type::from(expected_list_data); - assert!(a.equals(&expected_list_array)); + assert_eq!(a, &expected_list_array); }}; } @@ -976,7 +964,7 @@ mod tests { .build(); let expected_list_array = $list_array_type::from(expected_list_data); - assert!(a.equals(&expected_list_array)); + assert_eq!(a, &expected_list_array); }}; } @@ -1057,10 +1045,8 @@ mod tests { .add_child_data(expected_int_data) .build(); let struct_array = StructArray::from(struct_array_data); - assert!( - a.equals(&struct_array), - format!("{:?} =! {:?}", a.data(), struct_array.data()) - ); + + assert_eq!(a, &struct_array); } #[test] @@ -1090,10 +1076,7 @@ mod tests { .add_child_data(expected_int_data) .build(); let struct_array = StructArray::from(struct_array_data); - assert!( - a.equals(&struct_array), - format!("{:?} =! {:?}", a.data(), struct_array.data()) - ); + assert_eq!(a, &struct_array); } #[test]