diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml index dac331b44dd..b92fef3341e 100644 --- a/rust/arrow/Cargo.toml +++ b/rust/arrow/Cargo.toml @@ -88,6 +88,10 @@ harness = false name = "comparison_kernels" harness = false +[[bench]] +name = "filter_kernels" +harness = false + [[bench]] name = "take_kernels" harness = false diff --git a/rust/arrow/benches/filter_kernels.rs b/rust/arrow/benches/filter_kernels.rs new file mode 100644 index 00000000000..75c04352c0a --- /dev/null +++ b/rust/arrow/benches/filter_kernels.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::*; +use arrow::compute::{filter, FilterContext}; +use arrow::datatypes::ArrowNumericType; +use criterion::{criterion_group, criterion_main, Criterion}; + +fn create_primitive_array(size: usize, value_fn: F) -> PrimitiveArray +where + T: ArrowNumericType, + F: Fn(usize) -> T::Native, +{ + let mut builder = PrimitiveArray::::builder(size); + for i in 0..size { + builder.append_value(value_fn(i)).unwrap(); + } + builder.finish() +} + +fn create_u8_array_with_nulls(size: usize) -> UInt8Array { + let mut builder = UInt8Builder::new(size); + for i in 0..size { + if i % 2 == 0 { + builder.append_value(1).unwrap(); + } else { + builder.append_null().unwrap(); + } + } + builder.finish() +} + +fn create_bool_array(size: usize, value_fn: F) -> BooleanArray +where + F: Fn(usize) -> bool, +{ + let mut builder = BooleanBuilder::new(size); + for i in 0..size { + builder.append_value(value_fn(i)).unwrap(); + } + builder.finish() +} + +fn bench_filter_u8(data_array: &UInt8Array, filter_array: &BooleanArray) { + filter( + criterion::black_box(data_array), + criterion::black_box(filter_array), + ) + .unwrap(); +} + +// fn bench_filter_f32(data_array: &Float32Array, filter_array: &BooleanArray) { +// filter(criterion::black_box(data_array), criterion::black_box(filter_array)).unwrap(); +// } + +fn bench_filter_context_u8(data_array: &UInt8Array, filter_context: &FilterContext) { + filter_context + .filter(criterion::black_box(data_array)) + .unwrap(); +} + +fn bench_filter_context_f32(data_array: &Float32Array, filter_context: &FilterContext) { + filter_context + .filter(criterion::black_box(data_array)) + .unwrap(); +} + +fn add_benchmark(c: &mut Criterion) { + let size = 65536; + let filter_array = create_bool_array(size, |i| match i % 2 { + 0 => true, + _ => false, + }); + let sparse_filter_array = create_bool_array(size, |i| match i % 8000 { + 0 => true, + _ => false, + }); + let dense_filter_array = create_bool_array(size, |i| match i % 8000 { + 0 => false, + _ => true, + }); + + let filter_context = FilterContext::new(&filter_array).unwrap(); + let sparse_filter_context = FilterContext::new(&sparse_filter_array).unwrap(); + let dense_filter_context = FilterContext::new(&dense_filter_array).unwrap(); + + let data_array = create_primitive_array(size, |i| match i % 2 { + 0 => 1, + _ => 0, + }); + c.bench_function("filter u8 low selectivity", |b| { + b.iter(|| bench_filter_u8(&data_array, &filter_array)) + }); + c.bench_function("filter u8 high selectivity", |b| { + b.iter(|| bench_filter_u8(&data_array, &sparse_filter_array)) + }); + c.bench_function("filter u8 very low selectivity", |b| { + b.iter(|| bench_filter_u8(&data_array, &dense_filter_array)) + }); + + c.bench_function("filter context u8 low selectivity", |b| { + b.iter(|| bench_filter_context_u8(&data_array, &filter_context)) + }); + c.bench_function("filter context u8 high selectivity", |b| { + b.iter(|| bench_filter_context_u8(&data_array, &sparse_filter_context)) + }); + c.bench_function("filter context u8 very low selectivity", |b| { + b.iter(|| bench_filter_context_u8(&data_array, &dense_filter_context)) + }); + + let data_array = create_u8_array_with_nulls(size); + c.bench_function("filter context u8 w NULLs low selectivity", |b| { + b.iter(|| bench_filter_context_u8(&data_array, &filter_context)) + }); + c.bench_function("filter context u8 w NULLs high selectivity", |b| { + b.iter(|| bench_filter_context_u8(&data_array, &sparse_filter_context)) + }); + c.bench_function("filter context u8 w NULLs very low selectivity", |b| { + b.iter(|| bench_filter_context_u8(&data_array, &dense_filter_context)) + }); + + let data_array = create_primitive_array(size, |i| match i % 2 { + 0 => 1.0, + _ => 0.0, + }); + c.bench_function("filter context f32 low selectivity", |b| { + b.iter(|| bench_filter_context_f32(&data_array, &filter_context)) + }); + c.bench_function("filter context f32 high selectivity", |b| { + b.iter(|| bench_filter_context_f32(&data_array, &sparse_filter_context)) + }); + c.bench_function("filter context f32 very low selectivity", |b| { + b.iter(|| bench_filter_context_f32(&data_array, &dense_filter_context)) + }); +} + +criterion_group!(benches, add_benchmark); +criterion_main!(benches); diff --git a/rust/arrow/src/array/data.rs b/rust/arrow/src/array/data.rs index f8bf5cf7756..ad949460d08 100644 --- a/rust/arrow/src/array/data.rs +++ b/rust/arrow/src/array/data.rs @@ -151,6 +151,7 @@ impl ArrayData { } /// Returns the offset of this array + #[inline] pub fn offset(&self) -> usize { self.offset } diff --git a/rust/arrow/src/buffer.rs b/rust/arrow/src/buffer.rs index ca01f1d150b..33a0af9e4bd 100644 --- a/rust/arrow/src/buffer.rs +++ b/rust/arrow/src/buffer.rs @@ -545,10 +545,12 @@ impl MutableBuffer { /// /// Note that this should be used cautiously, and the returned pointer should not be /// stored anywhere, to avoid dangling pointers. + #[inline] pub fn raw_data(&self) -> *const u8 { self.data } + #[inline] pub fn raw_data_mut(&mut self) -> *mut u8 { self.data } diff --git a/rust/arrow/src/compute/kernels/filter.rs b/rust/arrow/src/compute/kernels/filter.rs index 52e12cfef19..98d70f05ced 100644 --- a/rust/arrow/src/compute/kernels/filter.rs +++ b/rust/arrow/src/compute/kernels/filter.rs @@ -17,139 +17,465 @@ //! Defines miscellaneous array kernels. -use std::sync::Arc; - use crate::array::*; use crate::datatypes::{ArrowNumericType, DataType, TimeUnit}; use crate::error::{ArrowError, Result}; +use crate::record_batch::RecordBatch; +use crate::{ + bitmap::Bitmap, + buffer::{Buffer, MutableBuffer}, + util::bit_util, +}; +use std::{mem, sync::Arc}; -/// Helper function to perform boolean lambda function on values from two arrays. -fn bool_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - op: F, -) -> Result -where - T: ArrowNumericType, - F: Fn(Option, Option) -> bool, -{ - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform math operation on arrays of different length".to_string(), - )); +/// trait for copying filtered null bitmap bits +trait CopyNullBit { + fn copy_null_bit(&mut self, source_index: usize); + fn copy_null_bits(&mut self, source_index: usize, count: usize); + fn null_count(&self) -> usize; + fn null_buffer(&mut self) -> Buffer; +} + +/// no-op null bitmap copy implementation, +/// used when the filtered data array doesn't have a null bitmap +struct NullBitNoop {} + +impl NullBitNoop { + fn new() -> Self { + NullBitNoop {} } - let mut b = BooleanArray::builder(left.len()); - for i in 0..left.len() { - let index = i; - let l = if left.is_null(i) { - None - } else { - Some(left.value(index)) - }; - let r = if right.is_null(i) { - None - } else { - Some(right.value(index)) - }; - b.append_value(op(l, r))?; +} + +impl CopyNullBit for NullBitNoop { + #[inline] + fn copy_null_bit(&mut self, _source_index: usize) { + // do nothing + } + + #[inline] + fn copy_null_bits(&mut self, _source_index: usize, _count: usize) { + // do nothing + } + + fn null_count(&self) -> usize { + 0 + } + + fn null_buffer(&mut self) -> Buffer { + Buffer::from([0u8; 0]) } - Ok(b.finish()) } -macro_rules! filter_array { - ($array:expr, $filter:expr, $array_type:ident) => {{ - let b = $array.as_any().downcast_ref::<$array_type>().unwrap(); - let mut builder = $array_type::builder(b.len()); - for i in 0..b.len() { - if $filter.value(i) { - if b.is_null(i) { - builder.append_null()?; - } else { - builder.append_value(b.value(i))?; - } - } - } - Ok(Arc::new(builder.finish())) - }}; +/// null bitmap copy implementation, +/// used when the filtered data array has a null bitmap +struct NullBitSetter<'a> { + target_buffer: MutableBuffer, + source_bytes: &'a [u8], + target_index: usize, + null_count: usize, } -/// Returns the array, taking only the elements matching the filter -pub fn filter(array: &Array, filter: &BooleanArray) -> Result { - match array.data_type() { - DataType::UInt8 => filter_array!(array, filter, UInt8Array), - DataType::UInt16 => filter_array!(array, filter, UInt16Array), - DataType::UInt32 => filter_array!(array, filter, UInt32Array), - DataType::UInt64 => filter_array!(array, filter, UInt64Array), - DataType::Int8 => filter_array!(array, filter, Int8Array), - DataType::Int16 => filter_array!(array, filter, Int16Array), - DataType::Int32 => filter_array!(array, filter, Int32Array), - DataType::Int64 => filter_array!(array, filter, Int64Array), - DataType::Float32 => filter_array!(array, filter, Float32Array), - DataType::Float64 => filter_array!(array, filter, Float64Array), - DataType::Boolean => filter_array!(array, filter, BooleanArray), - DataType::Date32(_) => filter_array!(array, filter, Date32Array), - DataType::Date64(_) => filter_array!(array, filter, Date64Array), - DataType::Time32(TimeUnit::Second) => { - filter_array!(array, filter, Time32SecondArray) - } - DataType::Time32(TimeUnit::Millisecond) => { - filter_array!(array, filter, Time32MillisecondArray) +impl<'a> NullBitSetter<'a> { + fn new(null_bitmap: &'a Bitmap) -> Self { + let null_bytes = null_bitmap.buffer_ref().data(); + // create null bitmap buffer with same length and initialize null bitmap buffer to 1s + let null_buffer = + MutableBuffer::new(null_bytes.len()).with_bitset(null_bytes.len(), true); + NullBitSetter { + source_bytes: null_bytes, + target_buffer: null_buffer, + target_index: 0, + null_count: 0, } - DataType::Time64(TimeUnit::Microsecond) => { - filter_array!(array, filter, Time64MicrosecondArray) - } - DataType::Time64(TimeUnit::Nanosecond) => { - filter_array!(array, filter, Time64NanosecondArray) - } - DataType::Duration(TimeUnit::Second) => { - filter_array!(array, filter, DurationSecondArray) - } - DataType::Duration(TimeUnit::Millisecond) => { - filter_array!(array, filter, DurationMillisecondArray) - } - DataType::Duration(TimeUnit::Microsecond) => { - filter_array!(array, filter, DurationMicrosecondArray) - } - DataType::Duration(TimeUnit::Nanosecond) => { - filter_array!(array, filter, DurationNanosecondArray) + } +} + +impl<'a> CopyNullBit for NullBitSetter<'a> { + #[inline] + fn copy_null_bit(&mut self, source_index: usize) { + if !bit_util::get_bit(self.source_bytes, source_index) { + bit_util::unset_bit(self.target_buffer.data_mut(), self.target_index); + self.null_count += 1; } - DataType::Timestamp(TimeUnit::Second, _) => { - filter_array!(array, filter, TimestampSecondArray) + self.target_index += 1; + } + + #[inline] + fn copy_null_bits(&mut self, source_index: usize, count: usize) { + for i in 0..count { + self.copy_null_bit(source_index + i); } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - filter_array!(array, filter, TimestampMillisecondArray) + } + + fn null_count(&self) -> usize { + self.null_count + } + + fn null_buffer(&mut self) -> Buffer { + self.target_buffer.resize(self.target_index).unwrap(); + // use mem::replace to detach self.target_buffer from self so that it can be returned + let target_buffer = mem::replace(&mut self.target_buffer, MutableBuffer::new(0)); + target_buffer.freeze() + } +} + +fn get_null_bit_setter<'a>(data_array: &'a impl Array) -> Box { + if let Some(null_bitmap) = data_array.data_ref().null_bitmap() { + // only return an actual null bit copy implementation if null_bitmap is set + Box::new(NullBitSetter::new(null_bitmap)) + } else { + // otherwise return a no-op copy null bit implementation + // for improved performance when the filtered array doesn't contain NULLs + Box::new(NullBitNoop::new()) + } +} + +// transmute filter array to u64 +// - optimize filtering with highly selective filters by skipping entire batches of 64 filter bits +// - if the data array being filtered doesn't have a null bitmap, no time is wasted to copy a null bitmap +fn filter_array_impl( + filter_context: &FilterContext, + data_array: &impl Array, + array_type: DataType, + value_size: usize, +) -> Result { + if filter_context.filter_len > data_array.len() { + return Err(ArrowError::ComputeError( + "Filter array cannot be larger than data array".to_string(), + )); + } + let filtered_count = filter_context.filtered_count; + let filter_mask = &filter_context.filter_mask; + let filter_u64 = &filter_context.filter_u64; + let data_bytes = data_array.data_ref().buffers()[0].data(); + let mut target_buffer = MutableBuffer::new(filtered_count * value_size); + target_buffer.resize(filtered_count * value_size)?; + let target_bytes = target_buffer.data_mut(); + let mut target_byte_index: usize = 0; + let mut null_bit_setter = get_null_bit_setter(data_array); + let null_bit_setter = null_bit_setter.as_mut(); + let all_ones_batch = !0u64; + let data_array_offset = data_array.offset(); + + for (i, filter_batch) in filter_u64.iter().enumerate() { + // foreach u64 batch + let filter_batch = *filter_batch; + if filter_batch == 0 { + // if batch == 0: skip + continue; + } else if filter_batch == all_ones_batch { + // if batch == all 1s: copy all 64 values in one go + let data_index = (i * 64) + data_array_offset; + null_bit_setter.copy_null_bits(data_index, 64); + let data_byte_index = data_index * value_size; + let data_len = value_size * 64; + target_bytes[target_byte_index..(target_byte_index + data_len)] + .copy_from_slice( + &data_bytes[data_byte_index..(data_byte_index + data_len)], + ); + target_byte_index += data_len; + continue; } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - filter_array!(array, filter, TimestampMicrosecondArray) + for (j, filter_mask) in filter_mask.iter().enumerate() { + // foreach bit in batch: + if (filter_batch & *filter_mask) != 0 { + let data_index = (i * 64) + j + data_array_offset; + null_bit_setter.copy_null_bit(data_index); + // if filter bit == 1: copy data value bytes + let data_byte_index = data_index * value_size; + target_bytes[target_byte_index..(target_byte_index + value_size)] + .copy_from_slice( + &data_bytes[data_byte_index..(data_byte_index + value_size)], + ); + target_byte_index += value_size; + } } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - filter_array!(array, filter, TimestampNanosecondArray) + } + + let mut array_data_builder = ArrayDataBuilder::new(array_type) + .len(filtered_count) + .add_buffer(target_buffer.freeze()); + if null_bit_setter.null_count() > 0 { + array_data_builder = array_data_builder + .null_count(null_bit_setter.null_count()) + .null_bit_buffer(null_bit_setter.null_buffer()); + } + + Ok(array_data_builder) +} + +/// FilterContext can be used to improve performance when +/// filtering multiple data arrays with the same filter array. +#[derive(Debug)] +pub struct FilterContext { + filter_u64: Vec, + filter_len: usize, + filtered_count: usize, + filter_mask: Vec, +} + +macro_rules! filter_primitive_array { + ($context:expr, $array:expr, $array_type:ident) => {{ + let input_array = $array.as_any().downcast_ref::<$array_type>().unwrap(); + let output_array = $context.filter_primitive_array(input_array)?; + Ok(Arc::new(output_array)) + }}; +} + +macro_rules! filter_dictionary_array { + ($context:expr, $array:expr, $array_type:ident) => {{ + let input_array = $array.as_any().downcast_ref::<$array_type>().unwrap(); + let output_array = $context.filter_dictionary_array(input_array)?; + Ok(Arc::new(output_array)) + }}; +} + +impl FilterContext { + /// Returns a new instance of FilterContext + pub fn new(filter_array: &BooleanArray) -> Result { + if filter_array.offset() > 0 { + return Err(ArrowError::ComputeError( + "Filter array cannot have offset > 0".to_string(), + )); } - DataType::Binary => { - let b = array.as_any().downcast_ref::().unwrap(); - let mut values: Vec<&[u8]> = Vec::with_capacity(b.len()); - for i in 0..b.len() { - if filter.value(i) { - values.push(b.value(i)); + let filter_mask: Vec = (0..64).map(|x| 1u64 << x).collect(); + let filter_bytes = filter_array.data_ref().buffers()[0].data(); + let filtered_count = bit_util::count_set_bits(filter_bytes); + // transmute filter_bytes to &[u64] + let mut u64_buffer = MutableBuffer::new(filter_bytes.len()); + u64_buffer + .write_bytes(filter_bytes, u64_buffer.capacity() - filter_bytes.len())?; + let filter_u64 = u64_buffer.typed_data_mut::().to_owned(); + Ok(FilterContext { + filter_u64, + filter_len: filter_array.len(), + filtered_count, + filter_mask, + }) + } + + /// Returns a new array, containing only the elements matching the filter + pub fn filter(&self, array: &Array) -> Result { + match array.data_type() { + DataType::UInt8 => filter_primitive_array!(self, array, UInt8Array), + DataType::UInt16 => filter_primitive_array!(self, array, UInt16Array), + DataType::UInt32 => filter_primitive_array!(self, array, UInt32Array), + DataType::UInt64 => filter_primitive_array!(self, array, UInt64Array), + DataType::Int8 => filter_primitive_array!(self, array, Int8Array), + DataType::Int16 => filter_primitive_array!(self, array, Int16Array), + DataType::Int32 => filter_primitive_array!(self, array, Int32Array), + DataType::Int64 => filter_primitive_array!(self, array, Int64Array), + DataType::Float32 => filter_primitive_array!(self, array, Float32Array), + DataType::Float64 => filter_primitive_array!(self, array, Float64Array), + DataType::Boolean => { + let input_array = array.as_any().downcast_ref::().unwrap(); + let mut builder = BooleanArray::builder(self.filtered_count); + for i in 0..self.filter_u64.len() { + // foreach u64 batch + let filter_batch = self.filter_u64[i]; + if filter_batch == 0 { + // if batch == 0: skip + continue; + } + for j in 0..64 { + // foreach bit in batch: + if (filter_batch & self.filter_mask[j]) != 0 { + let data_index = (i * 64) + j; + if input_array.is_null(data_index) { + builder.append_null()?; + } else { + builder.append_value(input_array.value(data_index))?; + } + } + } } + Ok(Arc::new(builder.finish())) + }, + DataType::Date32(_) => filter_primitive_array!(self, array, Date32Array), + DataType::Date64(_) => filter_primitive_array!(self, array, Date64Array), + DataType::Time32(TimeUnit::Second) => { + filter_primitive_array!(self, array, Time32SecondArray) } - Ok(Arc::new(BinaryArray::from(values))) - } - DataType::Utf8 => { - let b = array.as_any().downcast_ref::().unwrap(); - let mut values: Vec<&str> = Vec::with_capacity(b.len()); - for i in 0..b.len() { - if filter.value(i) { - values.push(b.value(i)); + DataType::Time32(TimeUnit::Millisecond) => { + filter_primitive_array!(self, array, Time32MillisecondArray) + } + DataType::Time64(TimeUnit::Microsecond) => { + filter_primitive_array!(self, array, Time64MicrosecondArray) + } + DataType::Time64(TimeUnit::Nanosecond) => { + filter_primitive_array!(self, array, Time64NanosecondArray) + } + DataType::Duration(TimeUnit::Second) => { + filter_primitive_array!(self, array, DurationSecondArray) + } + DataType::Duration(TimeUnit::Millisecond) => { + filter_primitive_array!(self, array, DurationMillisecondArray) + } + DataType::Duration(TimeUnit::Microsecond) => { + filter_primitive_array!(self, array, DurationMicrosecondArray) + } + DataType::Duration(TimeUnit::Nanosecond) => { + filter_primitive_array!(self, array, DurationNanosecondArray) + } + DataType::Timestamp(TimeUnit::Second, _) => { + filter_primitive_array!(self, array, TimestampSecondArray) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + filter_primitive_array!(self, array, TimestampMillisecondArray) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + filter_primitive_array!(self, array, TimestampMicrosecondArray) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + filter_primitive_array!(self, array, TimestampNanosecondArray) + } + DataType::Binary => { + let input_array = array.as_any().downcast_ref::().unwrap(); + let mut values: Vec<&[u8]> = Vec::with_capacity(self.filtered_count); + for i in 0..self.filter_u64.len() { + // foreach u64 batch + let filter_batch = self.filter_u64[i]; + if filter_batch == 0 { + // if batch == 0: skip + continue; + } + for j in 0..64 { + // foreach bit in batch: + if (filter_batch & self.filter_mask[j]) != 0 { + let data_index = (i * 64) + j; + values.push(input_array.value(data_index)); + } + } + } + Ok(Arc::new(BinaryArray::from(values))) + } + DataType::Utf8 => { + let input_array = array.as_any().downcast_ref::().unwrap(); + let mut values: Vec<&str> = Vec::with_capacity(self.filtered_count); + for i in 0..self.filter_u64.len() { + // foreach u64 batch + let filter_batch = self.filter_u64[i]; + if filter_batch == 0 { + // if batch == 0: skip + continue; + } + for j in 0..64 { + // foreach bit in batch: + if (filter_batch & self.filter_mask[j]) != 0 { + let data_index = (i * 64) + j; + values.push(input_array.value(data_index)); + } + } + } + Ok(Arc::new(StringArray::from(values))) + } + DataType::Dictionary(ref key_type, ref value_type) => match (key_type.as_ref(), value_type.as_ref()) { + (key_type, DataType::Utf8) => match key_type { + DataType::UInt8 => filter_dictionary_array!(self, array, UInt8DictionaryArray), + DataType::UInt16 => filter_dictionary_array!(self, array, UInt16DictionaryArray), + DataType::UInt32 => filter_dictionary_array!(self, array, UInt32DictionaryArray), + DataType::UInt64 => filter_dictionary_array!(self, array, UInt64DictionaryArray), + DataType::Int8 => filter_dictionary_array!(self, array, Int8DictionaryArray), + DataType::Int16 => filter_dictionary_array!(self, array, Int16DictionaryArray), + DataType::Int32 => filter_dictionary_array!(self, array, Int32DictionaryArray), + DataType::Int64 => filter_dictionary_array!(self, array, Int64DictionaryArray), + other => Err(ArrowError::ComputeError(format!( + "filter not supported for string dictionary with key of type {:?}", + other + ))) } + (key_type, value_type) => Err(ArrowError::ComputeError(format!( + "filter not supported for Dictionary({:?}, {:?})", + key_type, value_type + ))) } - Ok(Arc::new(StringArray::from(values))) + other => Err(ArrowError::ComputeError(format!( + "filter not supported for {:?}", + other + ))), } - other => Err(ArrowError::ComputeError(format!( - "filter not supported for {:?}", - other - ))), } + + /// Returns a new PrimitiveArray containing only those values from the array passed as the data_array parameter, + /// selected by the BooleanArray passed as the filter_array parameter + pub fn filter_primitive_array( + &self, + data_array: &PrimitiveArray, + ) -> Result> + where + T: ArrowNumericType, + { + let array_type = T::get_data_type(); + let value_size = mem::size_of::(); + let array_data_builder = + filter_array_impl(self, data_array, array_type, value_size)?; + let data = array_data_builder.build(); + Ok(PrimitiveArray::::from(data)) + } + + /// Returns a new DictionaryArray containing only those keys from the array passed as the data_array parameter, + /// selected by the BooleanArray passed as the filter_array parameter. The values are cloned from the data_array. + pub fn filter_dictionary_array( + &self, + data_array: &DictionaryArray, + ) -> Result> + where + T: ArrowNumericType, + { + let array_type = data_array.data_type().clone(); + let value_size = mem::size_of::(); + let mut array_data_builder = + filter_array_impl(self, data_array, array_type, value_size)?; + // copy dictionary values from input array + array_data_builder = + array_data_builder.add_child_data(data_array.values().data()); + let data = array_data_builder.build(); + Ok(DictionaryArray::::from(data)) + } +} + +/// Returns a new array, containing only the elements matching the filter. +pub fn filter(array: &Array, filter: &BooleanArray) -> Result { + FilterContext::new(filter)?.filter(array) +} + +/// Returns a new PrimitiveArray containing only those values from the array passed as the data_array parameter, +/// selected by the BooleanArray passed as the filter_array parameter +pub fn filter_primitive_array( + data_array: &PrimitiveArray, + filter_array: &BooleanArray, +) -> Result> +where + T: ArrowNumericType, +{ + FilterContext::new(filter_array)?.filter_primitive_array(data_array) +} + +/// Returns a new DictionaryArray containing only those keys from the array passed as the data_array parameter, +/// selected by the BooleanArray passed as the filter_array parameter. The values are cloned from the data_array. +pub fn filter_dictionary_array( + data_array: &DictionaryArray, + filter_array: &BooleanArray, +) -> Result> +where + T: ArrowNumericType, +{ + FilterContext::new(filter_array)?.filter_dictionary_array(data_array) +} + +/// Returns a new RecordBatch with arrays containing only values matching the filter. +/// The same FilterContext is re-used when filtering arrays in the RecordBatch for better performance. +pub fn filter_record_batch( + record_batch: &RecordBatch, + filter_array: &BooleanArray, +) -> Result { + let filter_context = FilterContext::new(filter_array)?; + let filtered_arrays = record_batch + .columns() + .iter() + .map(|a| filter_context.filter(a.as_ref())) + .collect::>>()?; + RecordBatch::try_new(record_batch.schema(), filtered_arrays) } #[cfg(test)] @@ -253,6 +579,73 @@ mod tests { assert_eq!(8, d.value(1)); } + #[test] + fn test_filter_array_slice() { + let a_slice = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4); + let a = a_slice.as_ref(); + let b = BooleanArray::from(vec![true, false, false, true]); + // filtering with sliced filter array is not currently supported + // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4); + // let b = b_slice.as_any().downcast_ref().unwrap(); + let c = filter(a, &b).unwrap(); + let d = c.as_ref().as_any().downcast_ref::().unwrap(); + assert_eq!(2, d.len()); + assert_eq!(6, d.value(0)); + assert_eq!(9, d.value(1)); + } + + #[test] + fn test_filter_array_low_density() { + // this test exercises the all 0's branch of the filter algorithm + let mut data_values = (1..=65).into_iter().collect::>(); + let mut filter_values = (1..=65) + .into_iter() + .map(|i| match i % 65 { + 0 => true, + _ => false, + }) + .collect::>(); + // set up two more values after the batch + data_values.extend_from_slice(&[66, 67]); + filter_values.extend_from_slice(&[false, true]); + let a = Int32Array::from(data_values); + let b = BooleanArray::from(filter_values); + let c = filter(&a, &b).unwrap(); + let d = c.as_ref().as_any().downcast_ref::().unwrap(); + assert_eq!(2, d.len()); + assert_eq!(65, d.value(0)); + assert_eq!(67, d.value(1)); + } + + #[test] + fn test_filter_array_high_density() { + // this test exercises the all 1's branch of the filter algorithm + let mut data_values = (1..=65).into_iter().map(|x| Some(x)).collect::>(); + let mut filter_values = (1..=65) + .into_iter() + .map(|i| match i % 65 { + 0 => false, + _ => true, + }) + .collect::>(); + // set second data value to null + data_values[1] = None; + // set up two more values after the batch + data_values.extend_from_slice(&[Some(66), None, Some(67), None]); + filter_values.extend_from_slice(&[false, true, true, true]); + let a = Int32Array::from(data_values); + let b = BooleanArray::from(filter_values); + let c = filter(&a, &b).unwrap(); + let d = c.as_ref().as_any().downcast_ref::().unwrap(); + assert_eq!(67, d.len()); + assert_eq!(3, d.null_count()); + assert_eq!(1, d.value(0)); + assert_eq!(true, d.is_null(1)); + assert_eq!(64, d.value(63)); + assert_eq!(true, d.is_null(64)); + assert_eq!(67, d.value(65)); + } + #[test] fn test_filter_string_array() { let a = StringArray::from(vec!["hello", " ", "world", "!"]); @@ -273,4 +666,45 @@ mod tests { assert_eq!(1, d.len()); assert_eq!(true, d.is_null(0)); } + + #[test] + fn test_filter_array_slice_with_null() { + let a_slice = + Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4); + let a = a_slice.as_ref(); + let b = BooleanArray::from(vec![true, false, false, true]); + // filtering with sliced filter array is not currently supported + // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4); + // let b = b_slice.as_any().downcast_ref().unwrap(); + let c = filter(a, &b).unwrap(); + let d = c.as_ref().as_any().downcast_ref::().unwrap(); + assert_eq!(2, d.len()); + assert_eq!(true, d.is_null(0)); + assert_eq!(false, d.is_null(1)); + assert_eq!(9, d.value(1)); + } + + #[test] + fn test_filter_dictionary_array() { + let values = vec![Some("hello"), None, Some("world"), Some("!")]; + let a: Int8DictionaryArray = values.iter().map(|&x| x).collect(); + let b = BooleanArray::from(vec![false, true, true, false]); + let c = filter(&a, &b).unwrap(); + let d = c + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let value_array = d.values(); + let values = value_array.as_any().downcast_ref::().unwrap(); + // values are cloned in the filtered dictionary array + assert_eq!(3, values.len()); + // but keys are filtered + assert_eq!(2, d.len()); + assert_eq!(true, d.is_null(0)); + assert_eq!( + "world", + values.value(d.keys().nth(1).unwrap().unwrap() as usize) + ); + } } diff --git a/rust/arrow/src/util/bit_util.rs b/rust/arrow/src/util/bit_util.rs index a2ada2c0323..d8ffa6f19c5 100644 --- a/rust/arrow/src/util/bit_util.rs +++ b/rust/arrow/src/util/bit_util.rs @@ -68,7 +68,7 @@ pub unsafe fn get_bit_raw(data: *const u8, i: usize) -> bool { /// Sets bit at position `i` for `data` #[inline] pub fn set_bit(data: &mut [u8], i: usize) { - data[i >> 3] |= BIT_MASK[i & 7] + data[i >> 3] |= BIT_MASK[i & 7]; } /// Sets bit at position `i` for `data` @@ -79,7 +79,24 @@ pub fn set_bit(data: &mut [u8], i: usize) { /// responsible to guarantee that `i` is within bounds. #[inline] pub unsafe fn set_bit_raw(data: *mut u8, i: usize) { - *data.add(i >> 3) |= BIT_MASK[i & 7] + *data.add(i >> 3) |= BIT_MASK[i & 7]; +} + +/// Sets bit at position `i` for `data` to 0 +#[inline] +pub fn unset_bit(data: &mut [u8], i: usize) { + data[i >> 3] ^= BIT_MASK[i & 7]; +} + +/// Sets bit at position `i` for `data` to 0 +/// +/// # Safety +/// +/// Note this doesn't do any bound checking, for performance reason. The caller is +/// responsible to guarantee that `i` is within bounds. +#[inline] +pub unsafe fn unset_bit_raw(data: *mut u8, i: usize) { + *data.add(i >> 3) ^= BIT_MASK[i & 7]; } /// Sets bits in the non-inclusive range `start..end` for `data` @@ -257,6 +274,17 @@ mod tests { assert_eq!([0b00100101], b); } + #[test] + fn test_unset_bit() { + let mut b = [0b11111111]; + unset_bit(&mut b, 0); + assert_eq!([0b11111110], b); + unset_bit(&mut b, 2); + assert_eq!([0b11111010], b); + unset_bit(&mut b, 5); + assert_eq!([0b11011010], b); + } + #[test] fn test_set_bit_raw() { const NUM_BYTE: usize = 10; @@ -281,6 +309,30 @@ mod tests { } } + #[test] + fn test_unset_bit_raw() { + const NUM_BYTE: usize = 10; + let mut buf = vec![255; NUM_BYTE]; + let mut expected = vec![]; + let mut rng = thread_rng(); + for i in 0..8 * NUM_BYTE { + let b = rng.gen_bool(0.5); + expected.push(b); + if !b { + unsafe { + unset_bit_raw(buf.as_mut_ptr(), i); + } + } + } + + let raw_ptr = buf.as_ptr(); + for (i, b) in expected.iter().enumerate() { + unsafe { + assert_eq!(*b, get_bit_raw(raw_ptr, i)); + } + } + } + #[test] fn test_set_bits_raw() { const NUM_BYTE: usize = 64;