diff --git a/rust/arrow/src/array/data.rs b/rust/arrow/src/array/data.rs index 5049ef1dd88..d634ed18d5c 100644 --- a/rust/arrow/src/array/data.rs +++ b/rust/arrow/src/array/data.rs @@ -28,7 +28,11 @@ use crate::{bitmap::Bitmap, datatypes::ArrowNativeType}; use super::equal::equal; #[inline] -fn count_nulls(null_bit_buffer: Option<&Buffer>, offset: usize, len: usize) -> usize { +pub(crate) fn count_nulls( + null_bit_buffer: Option<&Buffer>, + offset: usize, + len: usize, +) -> usize { if let Some(ref buf) = null_bit_buffer { len.checked_sub(buf.count_set_bits_offset(offset, len)) .unwrap() diff --git a/rust/arrow/src/array/equal/dictionary.rs b/rust/arrow/src/array/equal/dictionary.rs index a41b0a9b74e..087f8f8329b 100644 --- a/rust/arrow/src/array/equal/dictionary.rs +++ b/rust/arrow/src/array/equal/dictionary.rs @@ -40,6 +40,8 @@ pub(super) fn dictionary_equal( equal_range( lhs_values, rhs_values, + lhs_values.null_buffer(), + rhs_values.null_buffer(), lhs_keys[lhs_pos].to_usize().unwrap(), rhs_keys[rhs_pos].to_usize().unwrap(), 1, @@ -58,6 +60,8 @@ pub(super) fn dictionary_equal( && equal_range( lhs_values, rhs_values, + lhs_values.null_buffer(), + rhs_values.null_buffer(), lhs_keys[lhs_pos].to_usize().unwrap(), rhs_keys[rhs_pos].to_usize().unwrap(), 1, diff --git a/rust/arrow/src/array/equal/fixed_list.rs b/rust/arrow/src/array/equal/fixed_list.rs index aeb0d1372c8..ee0ee0f9c21 100644 --- a/rust/arrow/src/array/equal/fixed_list.rs +++ b/rust/arrow/src/array/equal/fixed_list.rs @@ -38,6 +38,8 @@ pub(super) fn fixed_list_equal( equal_range( lhs_values, rhs_values, + lhs_values.null_buffer(), + rhs_values.null_buffer(), size * lhs_start, size * rhs_start, size * len, @@ -56,6 +58,8 @@ pub(super) fn fixed_list_equal( && equal_range( lhs_values, rhs_values, + lhs_values.null_buffer(), + rhs_values.null_buffer(), lhs_pos * size, rhs_pos * size, size, // 1 * size since we are comparing a single entry diff --git a/rust/arrow/src/array/equal/list.rs b/rust/arrow/src/array/equal/list.rs index 6a9305edc11..c48c716d6e0 100644 --- a/rust/arrow/src/array/equal/list.rs +++ b/rust/arrow/src/array/equal/list.rs @@ -60,6 +60,8 @@ fn offset_value_equal( && equal_range( lhs_values, rhs_values, + lhs_values.null_buffer(), + rhs_values.null_buffer(), lhs_start, rhs_start, lhs_len.to_usize().unwrap(), @@ -86,6 +88,8 @@ pub(super) fn list_equal( ) && equal_range( lhs_values, rhs_values, + lhs_values.null_buffer(), + rhs_values.null_buffer(), lhs_offsets[lhs_start].to_usize().unwrap(), rhs_offsets[rhs_start].to_usize().unwrap(), (lhs_offsets[len] - lhs_offsets[lhs_start]) diff --git a/rust/arrow/src/array/equal/mod.rs b/rust/arrow/src/array/equal/mod.rs index 9ecdd078c3b..84d486d94b6 100644 --- a/rust/arrow/src/array/equal/mod.rs +++ b/rust/arrow/src/array/equal/mod.rs @@ -25,7 +25,10 @@ use super::{ PrimitiveArray, StringOffsetSizeTrait, StructArray, }; -use crate::datatypes::{ArrowPrimitiveType, DataType, IntervalUnit}; +use crate::{ + buffer::Buffer, + datatypes::{ArrowPrimitiveType, DataType, IntervalUnit}, +}; mod boolean; mod decimal; @@ -108,15 +111,49 @@ impl PartialEq for StructArray { } /// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively -/// for `len` slots. +/// for `len` slots. The null buffers `lhs_nulls` and `rhs_nulls` inherit parent nullability. +/// +/// If an array is a child of a struct or list, the array's nulls have to be merged with the parent. +/// This then affects the null count of the array, thus the merged nulls are passed separately +/// as `lhs_nulls` and `rhs_nulls` variables to functions. +/// The nulls are merged with a bitwise AND, and null counts are recomputed wheer necessary. #[inline] fn equal_values( lhs: &ArrayData, rhs: &ArrayData, + lhs_nulls: Option<&Buffer>, + rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { + // compute the nested buffer of the parent and child + // if the array has no parent, the child is computed with itself + #[allow(unused_assignments)] + let mut temp_lhs: Option = None; + #[allow(unused_assignments)] + let mut temp_rhs: Option = None; + let lhs_merged_nulls = match (lhs_nulls, lhs.null_buffer()) { + (None, None) => None, + (None, Some(c)) => Some(c), + (Some(p), None) => Some(p), + (Some(p), Some(c)) => { + let merged = (p & c).unwrap(); + temp_lhs = Some(merged); + temp_lhs.as_ref() + } + }; + let rhs_merged_nulls = match (rhs_nulls, rhs.null_buffer()) { + (None, None) => None, + (None, Some(c)) => Some(c), + (Some(p), None) => Some(p), + (Some(p), Some(c)) => { + let merged = (p & c).unwrap(); + temp_rhs = Some(merged); + temp_rhs.as_ref() + } + }; + match lhs.data_type() { DataType::Null => null_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Boolean => boolean_equal(lhs, rhs, lhs_start, rhs_start, len), @@ -142,12 +179,24 @@ fn equal_values( | DataType::Duration(_) => { primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) } - DataType::Utf8 | DataType::Binary => { - variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) - } - DataType::LargeUtf8 | DataType::LargeBinary => { - variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) - } + DataType::Utf8 | DataType::Binary => variable_sized_equal::( + lhs, + rhs, + lhs_merged_nulls, + rhs_merged_nulls, + lhs_start, + rhs_start, + len, + ), + DataType::LargeUtf8 | DataType::LargeBinary => variable_sized_equal::( + lhs, + rhs, + lhs_merged_nulls, + rhs_merged_nulls, + lhs_start, + rhs_start, + len, + ), DataType::FixedSizeBinary(_) => { fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len) } @@ -157,7 +206,15 @@ fn equal_values( DataType::FixedSizeList(_, _) => { fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len) } - DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Struct(_) => struct_equal( + lhs, + rhs, + lhs_merged_nulls, + rhs_merged_nulls, + lhs_start, + rhs_start, + len, + ), DataType::Union(_) => unimplemented!("See ARROW-8576"), DataType::Dictionary(data_type, _) => match data_type.as_ref() { DataType::Int8 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), @@ -191,13 +248,15 @@ fn equal_values( fn equal_range( lhs: &ArrayData, rhs: &ArrayData, + lhs_nulls: Option<&Buffer>, + rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { utils::base_equal(lhs, rhs) - && utils::equal_nulls(lhs, rhs, lhs_start, rhs_start, len) - && equal_values(lhs, rhs, lhs_start, rhs_start, len) + && utils::equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) } /// Logically compares two [ArrayData]. @@ -213,10 +272,12 @@ fn equal_range( /// This function may panic whenever any of the [ArrayData] does not follow the Arrow specification. /// (e.g. wrong number of buffers, buffer `len` does not correspond to the declared `len`) pub fn equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { + let lhs_nulls = lhs.null_buffer(); + let rhs_nulls = rhs.null_buffer(); utils::base_equal(lhs, rhs) && lhs.null_count() == rhs.null_count() - && utils::equal_nulls(lhs, rhs, 0, 0, lhs.len()) - && equal_values(lhs, rhs, 0, 0, lhs.len()) + && utils::equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, 0, 0, lhs.len()) + && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, 0, 0, lhs.len()) } #[cfg(test)] @@ -231,7 +292,8 @@ mod tests { StringDictionaryBuilder, StringOffsetSizeTrait, StructArray, }; use crate::array::{GenericStringArray, Int32Array}; - use crate::datatypes::Int16Type; + use crate::buffer::Buffer; + use crate::datatypes::{Field, Int16Type}; use super::*; @@ -841,6 +903,180 @@ mod tests { test_equal(a.as_ref(), b.as_ref(), true); } + #[test] + fn test_struct_equal_null() { + let strings: ArrayRef = Arc::new(StringArray::from(vec![ + Some("joe"), + None, + None, + Some("mark"), + Some("doe"), + ])); + let ints: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(4), + Some(5), + ])); + let ints_non_null: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 0])); + + let a = ArrayData::builder(DataType::Struct(vec![ + Field::new("f1", DataType::Utf8, true), + Field::new("f2", DataType::Int32, true), + ])) + .null_bit_buffer(Buffer::from(vec![0b00001011])) + .len(5) + .null_count(2) + .add_child_data(strings.data_ref().clone()) + .add_child_data(ints.data_ref().clone()) + .build(); + let a = crate::array::make_array(a); + + let b = ArrayData::builder(DataType::Struct(vec![ + Field::new("f1", DataType::Utf8, true), + Field::new("f2", DataType::Int32, true), + ])) + .null_bit_buffer(Buffer::from(vec![0b00001011])) + .len(5) + .null_count(2) + .add_child_data(strings.data_ref().clone()) + .add_child_data(ints_non_null.data_ref().clone()) + .build(); + let b = crate::array::make_array(b); + + test_equal(a.data_ref(), b.data_ref(), true); + + // test with arrays that are not equal + let c_ints_non_null: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 0, 4])); + let c = ArrayData::builder(DataType::Struct(vec![ + Field::new("f1", DataType::Utf8, true), + Field::new("f2", DataType::Int32, true), + ])) + .null_bit_buffer(Buffer::from(vec![0b00001011])) + .len(5) + .null_count(2) + .add_child_data(strings.data_ref().clone()) + .add_child_data(c_ints_non_null.data_ref().clone()) + .build(); + let c = crate::array::make_array(c); + + test_equal(a.data_ref(), c.data_ref(), false); + + // test a nested struct + let a = ArrayData::builder(DataType::Struct(vec![Field::new( + "f3", + a.data_type().clone(), + true, + )])) + .null_bit_buffer(Buffer::from(vec![0b00011110])) + .len(5) + .null_count(1) + .add_child_data(a.data_ref().clone()) + .build(); + let a = crate::array::make_array(a); + + // reconstruct b, but with different data where the first struct is null + let strings: ArrayRef = Arc::new(StringArray::from(vec![ + Some("joanne"), // difference + None, + None, + Some("mark"), + Some("doe"), + ])); + let b = ArrayData::builder(DataType::Struct(vec![ + Field::new("f1", DataType::Utf8, true), + Field::new("f2", DataType::Int32, true), + ])) + .null_bit_buffer(Buffer::from(vec![0b00001011])) + .len(5) + .null_count(2) + .add_child_data(strings.data_ref().clone()) + .add_child_data(ints_non_null.data_ref().clone()) + .build(); + + let b = ArrayData::builder(DataType::Struct(vec![Field::new( + "f3", + b.data_type().clone(), + true, + )])) + .null_bit_buffer(Buffer::from(vec![0b00011110])) + .len(5) + .null_count(1) + .add_child_data(b) + .build(); + let b = crate::array::make_array(b); + + test_equal(a.data_ref(), b.data_ref(), true); + } + + #[test] + fn test_struct_equal_null_variable_size() { + // the string arrays differ, but where the struct array is null + let strings1: ArrayRef = Arc::new(StringArray::from(vec![ + Some("joe"), + None, + None, + Some("mark"), + Some("doel"), + ])); + let strings2: ArrayRef = Arc::new(StringArray::from(vec![ + Some("joel"), + None, + None, + Some("mark"), + Some("doe"), + ])); + + let a = ArrayData::builder(DataType::Struct(vec![Field::new( + "f1", + DataType::Utf8, + true, + )])) + .null_bit_buffer(Buffer::from(vec![0b00001010])) + .len(5) + .null_count(3) + .add_child_data(strings1.data_ref().clone()) + .build(); + let a = crate::array::make_array(a); + + let b = ArrayData::builder(DataType::Struct(vec![Field::new( + "f1", + DataType::Utf8, + true, + )])) + .null_bit_buffer(Buffer::from(vec![0b00001010])) + .len(5) + .null_count(3) + .add_child_data(strings2.data_ref().clone()) + .build(); + let b = crate::array::make_array(b); + + test_equal(a.data_ref(), b.data_ref(), true); + + // test with arrays that are not equal + let strings3: ArrayRef = Arc::new(StringArray::from(vec![ + Some("mark"), + None, + None, + Some("doe"), + Some("joe"), + ])); + let c = ArrayData::builder(DataType::Struct(vec![Field::new( + "f1", + DataType::Utf8, + true, + )])) + .null_bit_buffer(Buffer::from(vec![0b00001011])) + .len(5) + .null_count(2) + .add_child_data(strings3.data_ref().clone()) + .build(); + let c = crate::array::make_array(c); + + test_equal(a.data_ref(), c.data_ref(), false); + } + fn create_dictionary_array(values: &[&str], keys: &[Option<&str>]) -> ArrayDataRef { let values = StringArray::from(values.to_vec()); let mut builder = StringDictionaryBuilder::new_with_dictionary( diff --git a/rust/arrow/src/array/equal/structure.rs b/rust/arrow/src/array/equal/structure.rs index 1e8a1ff260b..5b12ae776d9 100644 --- a/rust/arrow/src/array/equal/structure.rs +++ b/rust/arrow/src/array/equal/structure.rs @@ -15,45 +15,97 @@ // specific language governing permissions and limitations // under the License. -use crate::array::ArrayData; +use crate::{ + array::data::count_nulls, array::ArrayData, buffer::Buffer, util::bit_util::get_bit, +}; use super::equal_range; +/// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively +/// for `len` slots. The null buffers `lhs_nulls` and `rhs_nulls` inherit parent nullability. +/// +/// If an array is a child of a struct or list, the array's nulls have to be merged with the parent. +/// This then affects the null count of the array, thus the merged nulls are passed separately +/// as `lhs_nulls` and `rhs_nulls` variables to functions. +/// The nulls are merged with a bitwise AND, and null counts are recomputed wheer necessary. fn equal_values( lhs: &ArrayData, rhs: &ArrayData, + lhs_nulls: Option<&Buffer>, + rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { + let mut temp_lhs: Option = None; + let mut temp_rhs: Option = None; + lhs.child_data() .iter() .zip(rhs.child_data()) .all(|(lhs_values, rhs_values)| { - equal_range(lhs_values, rhs_values, lhs_start, rhs_start, len) + // merge the null data + let lhs_merged_nulls = match (lhs_nulls, lhs_values.null_buffer()) { + (None, None) => None, + (None, Some(c)) => Some(c), + (Some(p), None) => Some(p), + (Some(p), Some(c)) => { + let merged = (p & c).unwrap(); + temp_lhs = Some(merged); + temp_lhs.as_ref() + } + }; + let rhs_merged_nulls = match (rhs_nulls, rhs_values.null_buffer()) { + (None, None) => None, + (None, Some(c)) => Some(c), + (Some(p), None) => Some(p), + (Some(p), Some(c)) => { + let merged = (p & c).unwrap(); + temp_rhs = Some(merged); + temp_rhs.as_ref() + } + }; + equal_range( + lhs_values, + rhs_values, + lhs_merged_nulls, + rhs_merged_nulls, + lhs_start, + rhs_start, + len, + ) }) } pub(super) fn struct_equal( lhs: &ArrayData, rhs: &ArrayData, + lhs_nulls: Option<&Buffer>, + rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { - if lhs.null_count() == 0 && rhs.null_count() == 0 { - equal_values(lhs, rhs, lhs_start, rhs_start, len) + // we have to recalculate null counts from the null buffers + let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); + let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + if lhs_null_count == 0 && rhs_null_count == 0 { + equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) } else { + // get a ref of the null buffer bytes, to use in testing for nullness + let lhs_null_bytes = lhs_nulls.as_ref().unwrap().data(); + let rhs_null_bytes = rhs_nulls.as_ref().unwrap().data(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; - let lhs_is_null = lhs.is_null(lhs_pos); - let rhs_is_null = rhs.is_null(rhs_pos); + // if both struct and child had no null buffers, + let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos); + let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos); lhs_is_null || (lhs_is_null == rhs_is_null) - && equal_values(lhs, rhs, lhs_pos, rhs_pos, 1) + && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_pos, rhs_pos, 1) }) } } diff --git a/rust/arrow/src/array/equal/utils.rs b/rust/arrow/src/array/equal/utils.rs index f9e8860a5bb..3bb4c0be653 100644 --- a/rust/arrow/src/array/equal/utils.rs +++ b/rust/arrow/src/array/equal/utils.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{array::ArrayData, util::bit_util}; +use crate::{array::data::count_nulls, array::ArrayData, buffer::Buffer, util::bit_util}; // whether bits along the positions are equal // `lhs_start`, `rhs_start` and `len` are _measured in bits_. @@ -37,15 +37,17 @@ pub(super) fn equal_bits( pub(super) fn equal_nulls( lhs: &ArrayData, rhs: &ArrayData, + lhs_nulls: Option<&Buffer>, + rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { - if lhs.null_count() > 0 || rhs.null_count() > 0 { - let lhs_null_bitmap = lhs.null_bitmap().as_ref().unwrap(); - let rhs_null_bitmap = rhs.null_bitmap().as_ref().unwrap(); - let lhs_values = lhs_null_bitmap.bits.data(); - let rhs_values = rhs_null_bitmap.bits.data(); + let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); + let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + if lhs_null_count > 0 || rhs_null_count > 0 { + let lhs_values = lhs_nulls.unwrap().data(); + let rhs_values = rhs_nulls.unwrap().data(); equal_bits( lhs_values, rhs_values, diff --git a/rust/arrow/src/array/equal/variable_size.rs b/rust/arrow/src/array/equal/variable_size.rs index c26ec6cc1b3..caf8a0c1eae 100644 --- a/rust/arrow/src/array/equal/variable_size.rs +++ b/rust/arrow/src/array/equal/variable_size.rs @@ -15,7 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{ArrayData, OffsetSizeTrait}; +use crate::buffer::Buffer; +use crate::util::bit_util::get_bit; +use crate::{ + array::data::count_nulls, + array::{ArrayData, OffsetSizeTrait}, +}; use super::utils::equal_len; @@ -46,6 +51,8 @@ fn offset_value_equal( pub(super) fn variable_sized_equal( lhs: &ArrayData, rhs: &ArrayData, + lhs_nulls: Option<&Buffer>, + rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -57,8 +64,11 @@ pub(super) fn variable_sized_equal( let lhs_values = &lhs.buffers()[1].data()[lhs.offset()..]; let rhs_values = &rhs.buffers()[1].data()[rhs.offset()..]; - if lhs.null_count() == 0 - && rhs.null_count() == 0 + let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); + let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + + if lhs_null_count == 0 + && rhs_null_count == 0 && !lhs_values.is_empty() && !rhs_values.is_empty() { @@ -76,8 +86,13 @@ pub(super) fn variable_sized_equal( let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; - let lhs_is_null = lhs.is_null(lhs_pos); - let rhs_is_null = rhs.is_null(rhs_pos); + // the null bits can still be `None`, so we don't unwrap + let lhs_is_null = !lhs_nulls + .map(|v| get_bit(v.data(), lhs_pos)) + .unwrap_or(false); + let rhs_is_null = !rhs_nulls + .map(|v| get_bit(v.data(), rhs_pos)) + .unwrap_or(false); lhs_is_null || (lhs_is_null == rhs_is_null)