diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index 954715d0e5a9c..2efb1778a1d4d 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -15,264 +15,950 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ - Array, ArrayRef, Float32Array, Int16Array, Int32Array, StringArray, StringViewArray, - TimestampNanosecondArray, UInt8Array, -}; -use arrow::datatypes::{Field, Schema}; +//! Focused benchmarks for InList optimizations +//! +//! This benchmark file provides targeted coverage of each optimization strategy +//! with controlled parameters to ensure statistical robustness: +//! +//! - **Controlled match rates**: Tests both "found" and "not found" code paths +//! - **List size scaling**: Measures performance across different list sizes +//! - **Strategy coverage**: Each optimization has dedicated benchmarks +//! - **Reinterpret coverage**: Tests types that use zero-copy reinterpretation +//! - **Stage 2 stress testing**: Prefix-collision strings for two-stage filters +//! - **Null handling**: Tests null short-circuit optimization paths +//! +//! # Optimization Coverage +//! +//! | Strategy | Types | Threshold | List Sizes Tested | +//! |----------|-------|-----------|-------------------| +//! | BitmapFilter (stack) | UInt8 | always | 4, 16 | +//! | BitmapFilter (heap) | Int16 | always | 4, 64, 256 | +//! | BranchlessFilter | Int32, Float32 | ≤32 | 4, 32 | +//! | DirectProbeFilter | Int32, Float32 | >32 | 64, 256 | +//! | BranchlessFilter | Int64, TimestampNs | ≤16 | 4, 16 | +//! | DirectProbeFilter | Int64, TimestampNs | >16 | 32, 128 | +//! | Utf8TwoStageFilter | Utf8 | always | 4, 64, 256 | +//! 
| ByteViewMaskedFilter | Utf8View | always | 4, 16, 64, 256 | + +use arrow::array::*; +use arrow::datatypes::{Field, Int32Type, Schema}; use arrow::record_batch::RecordBatch; -use criterion::{Criterion, criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; use datafusion_physical_expr::expressions::{col, in_list, lit}; use rand::distr::Alphanumeric; use rand::prelude::*; -use std::any::TypeId; -use std::hint::black_box; use std::sync::Arc; -use std::time::Duration; -/// Measures how long `in_list(col("a"), exprs)` takes to evaluate against a single RecordBatch. -fn do_bench(c: &mut Criterion, name: &str, values: ArrayRef, exprs: &[ScalarValue]) { +const ARRAY_SIZE: usize = 8192; + +/// Match rates to test both code paths (miss-heavy and balanced) +const MATCH_RATES: [u32; 2] = [0, 50]; + +// ============================================================================= +// NUMERIC BENCHMARK HELPERS +// ============================================================================= + +/// Configuration for numeric benchmarks, grouping test parameters. +struct NumericBenchConfig { + list_size: usize, + match_rate: f64, + null_rate: f64, + make_value: fn(&mut StdRng) -> T, + to_scalar: fn(T) -> ScalarValue, + negated: bool, +} + +impl NumericBenchConfig { + fn new( + list_size: usize, + match_rate: f64, + make_value: fn(&mut StdRng) -> T, + to_scalar: fn(T) -> ScalarValue, + ) -> Self { + Self { + list_size, + match_rate, + null_rate: 0.0, + make_value, + to_scalar, + negated: false, + } + } + + fn with_null_rate(mut self, null_rate: f64) -> Self { + self.null_rate = null_rate; + self + } + + fn with_negated(mut self) -> Self { + self.negated = true; + self + } +} + +/// Creates and runs a benchmark for numeric types with controlled match rate. +/// Uses a seed derived from list_size to avoid subset correlation between sizes. 
+fn bench_numeric( + c: &mut Criterion, + group: &str, + name: &str, + cfg: &NumericBenchConfig, +) where + T: Clone, + A: Array + FromIterator> + 'static, +{ + // Use different seed per list_size to avoid subset correlation + let seed = 0xDEAD_BEEF_u64.wrapping_add(cfg.list_size as u64 * 0x1234_5678); + let mut rng = StdRng::seed_from_u64(seed); + + // Generate IN list values + let haystack: Vec = (0..cfg.list_size) + .map(|_| (cfg.make_value)(&mut rng)) + .collect(); + + // Generate array with controlled match rate and null rate + let values: A = (0..ARRAY_SIZE) + .map(|_| { + if cfg.null_rate > 0.0 && rng.random_bool(cfg.null_rate) { + None + } else if !haystack.is_empty() && rng.random_bool(cfg.match_rate) { + Some(haystack.choose(&mut rng).unwrap().clone()) + } else { + Some((cfg.make_value)(&mut rng)) + } + }) + .collect(); + let schema = Schema::new(vec![Field::new("a", values.data_type().clone(), true)]); - let exprs = exprs.iter().map(|s| lit(s.clone())).collect(); - let expr = in_list(col("a", &schema).unwrap(), exprs, &false, &schema).unwrap(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![values]).unwrap(); + let exprs: Vec<_> = haystack + .iter() + .map(|v: &T| lit((cfg.to_scalar)(v.clone()))) + .collect(); + let expr = in_list(col("a", &schema).unwrap(), exprs, &cfg.negated, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values) as ArrayRef]) + .unwrap(); - c.bench_function(name, |b| { - b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + c.bench_with_input(BenchmarkId::new(group, name), &batch, |b, batch| { + b.iter(|| expr.evaluate(batch).unwrap()) }); } -/// Generates a random alphanumeric string of the specified length. 
+// ============================================================================= +// STRING BENCHMARK HELPERS +// ============================================================================= + fn random_string(rng: &mut StdRng, len: usize) -> String { - let value = rng.sample_iter(&Alphanumeric).take(len).collect(); - String::from_utf8(value).unwrap() + String::from_utf8(rng.sample_iter(&Alphanumeric).take(len).collect()).unwrap() +} + +/// Creates a set of strings that share a common prefix but differ in suffix. +/// Uses random alphanumeric suffix to avoid bench-maxing on numeric patterns. +fn strings_with_shared_prefix( + rng: &mut StdRng, + count: usize, + prefix_len: usize, +) -> Vec { + let prefix = random_string(rng, prefix_len); + (0..count) + .map(|_| format!("{}{}", prefix, random_string(rng, 8))) // prefix + random 8-char suffix + .collect() +} + +/// Configuration for string benchmarks, grouping test parameters. +struct StringBenchConfig { + list_size: usize, + match_rate: f64, + null_rate: f64, + string_len: usize, + to_scalar: fn(String) -> ScalarValue, + negated: bool, } -const IN_LIST_LENGTHS: [usize; 4] = [3, 8, 28, 100]; -const NULL_PERCENTS: [f64; 2] = [0., 0.2]; -const STRING_LENGTHS: [usize; 3] = [3, 12, 100]; -const ARRAY_LENGTH: usize = 8192; - -/// Mixed string lengths for realistic benchmarks. -/// ~50% short (≤12 bytes), ~50% long (>12 bytes). -const MIXED_STRING_LENGTHS: &[usize] = &[3, 6, 9, 12, 16, 20, 25, 30]; - -/// Returns a friendly type name for the array type. 
-fn array_type_name() -> &'static str { - let id = TypeId::of::(); - if id == TypeId::of::() { - "Utf8" - } else if id == TypeId::of::() { - "Utf8View" - } else if id == TypeId::of::() { - "Float32" - } else if id == TypeId::of::() { - "Int16" - } else if id == TypeId::of::() { - "Int32" - } else if id == TypeId::of::() { - "TimestampNs" - } else if id == TypeId::of::() { - "UInt8" - } else { - "Unknown" +impl StringBenchConfig { + fn new( + list_size: usize, + match_rate: f64, + string_len: usize, + to_scalar: fn(String) -> ScalarValue, + ) -> Self { + Self { + list_size, + match_rate, + null_rate: 0.0, + string_len, + to_scalar, + negated: false, + } + } + + fn with_null_rate(mut self, null_rate: f64) -> Self { + self.null_rate = null_rate; + self + } + + fn with_negated(mut self) -> Self { + self.negated = true; + self } } -/// Builds a benchmark name from array type, list size, and null percentage. -fn bench_name(in_list_length: usize, null_percent: f64) -> String { - format!( - "in_list/{}/list={in_list_length}/nulls={}%", - array_type_name::(), - (null_percent * 100.0) as u32 - ) +/// Creates and runs a benchmark for string types with controlled match rate. +/// Uses a seed derived from list_size and string_len to avoid correlation. 
+fn bench_string(c: &mut Criterion, group: &str, name: &str, cfg: &StringBenchConfig) +where + A: Array + FromIterator> + 'static, +{ + // Use different seed per (list_size, string_len) to avoid correlation + let seed = 0xCAFE_BABE_u64 + .wrapping_add(cfg.list_size as u64 * 0x1111) + .wrapping_add(cfg.string_len as u64 * 0x2222); + let mut rng = StdRng::seed_from_u64(seed); + + // Generate IN list values + let haystack: Vec = (0..cfg.list_size) + .map(|_| random_string(&mut rng, cfg.string_len)) + .collect(); + + // Generate array with controlled match rate and null rate + let values: A = (0..ARRAY_SIZE) + .map(|_| { + if cfg.null_rate > 0.0 && rng.random_bool(cfg.null_rate) { + None + } else if !haystack.is_empty() && rng.random_bool(cfg.match_rate) { + Some(haystack.choose(&mut rng).unwrap().clone()) + } else { + Some(random_string(&mut rng, cfg.string_len)) + } + }) + .collect(); + + let schema = Schema::new(vec![Field::new("a", values.data_type().clone(), true)]); + let exprs: Vec<_> = haystack + .iter() + .map(|v| lit((cfg.to_scalar)(v.clone()))) + .collect(); + let expr = in_list(col("a", &schema).unwrap(), exprs, &cfg.negated, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values) as ArrayRef]) + .unwrap(); + + c.bench_with_input(BenchmarkId::new(group, name), &batch, |b, batch| { + b.iter(|| expr.evaluate(batch).unwrap()) + }); } -/// Runs in_list benchmarks for a string array type across all list-size × null-ratio × string-length combinations. -fn bench_string_type( +/// Benchmarks strings with shared prefixes to stress Stage 2 of two-stage filters. +/// Uses variable prefix lengths and random suffixes to avoid bench-maxing. 
+fn bench_string_prefix_collision( c: &mut Criterion, - rng: &mut StdRng, - make_scalar: fn(String) -> ScalarValue, + group: &str, + name: &str, + list_size: usize, + match_rate: f64, + prefix_len: usize, + to_scalar: fn(String) -> ScalarValue, ) where A: Array + FromIterator> + 'static, { - for in_list_length in IN_LIST_LENGTHS { - for null_percent in NULL_PERCENTS { - for string_length in STRING_LENGTHS { - let values: A = (0..ARRAY_LENGTH) - .map(|_| { - rng.random_bool(1.0 - null_percent) - .then(|| random_string(rng, string_length)) - }) - .collect(); - - let in_list: Vec<_> = (0..in_list_length) - .map(|_| make_scalar(random_string(rng, string_length))) - .collect(); - - do_bench( - c, - &format!( - "{}/str={string_length}", - bench_name::(in_list_length, null_percent) - ), - Arc::new(values), - &in_list, - ) - } - } - } + let seed = 0xFEED_FACE_u64 + .wrapping_add(list_size as u64 * 0x3333) + .wrapping_add(prefix_len as u64 * 0x4444); + let mut rng = StdRng::seed_from_u64(seed); + + // Generate IN list with shared prefix (forces Stage 2) + let haystack = strings_with_shared_prefix(&mut rng, list_size, prefix_len); + + // Generate non-matching strings with SAME prefix (will pass Stage 1, fail Stage 2) + let non_match_pool = strings_with_shared_prefix(&mut rng, 100, prefix_len); + + // Generate array with controlled match rate + let values: A = (0..ARRAY_SIZE) + .map(|_| { + Some(if !haystack.is_empty() && rng.random_bool(match_rate) { + haystack.choose(&mut rng).unwrap().clone() + } else { + non_match_pool.choose(&mut rng).unwrap().clone() + }) + }) + .collect(); + + let schema = Schema::new(vec![Field::new("a", values.data_type().clone(), true)]); + let exprs: Vec<_> = haystack.iter().map(|v| lit(to_scalar(v.clone()))).collect(); + let expr = in_list(col("a", &schema).unwrap(), exprs, &false, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values) as ArrayRef]) + .unwrap(); + + 
c.bench_with_input(BenchmarkId::new(group, name), &batch, |b, batch| { + b.iter(|| expr.evaluate(batch).unwrap()) + }); } -/// Runs in_list benchmarks for a numeric array type across all list-size × null-ratio combinations. -fn bench_numeric_type( +/// Benchmarks mixed-length strings (some short ≤12, some long >12). +/// Tests the two-stage filter with realistic length distribution. +fn bench_string_mixed_lengths( c: &mut Criterion, - rng: &mut StdRng, - mut gen_value: impl FnMut(&mut StdRng) -> T, - make_scalar: fn(T) -> ScalarValue, + group: &str, + name: &str, + list_size: usize, + match_rate: f64, + to_scalar: fn(String) -> ScalarValue, ) where - A: Array + FromIterator> + 'static, + A: Array + FromIterator> + 'static, { - for in_list_length in IN_LIST_LENGTHS { - for null_percent in NULL_PERCENTS { - let values: A = (0..ARRAY_LENGTH) - .map(|_| rng.random_bool(1.0 - null_percent).then(|| gen_value(rng))) - .collect(); + let seed = 0xABCD_EF01_u64.wrapping_add(list_size as u64 * 0x5555); + let mut rng = StdRng::seed_from_u64(seed); + + // Mixed lengths: some short (≤12), some long (>12) + let lengths = [4, 8, 12, 16, 20, 24]; + + // Generate IN list with mixed lengths + let haystack: Vec = (0..list_size) + .map(|_| { + let len = *lengths.choose(&mut rng).unwrap(); + random_string(&mut rng, len) + }) + .collect(); + + // Generate array with controlled match rate and mixed lengths + let values: A = (0..ARRAY_SIZE) + .map(|_| { + Some(if !haystack.is_empty() && rng.random_bool(match_rate) { + haystack.choose(&mut rng).unwrap().clone() + } else { + let len = *lengths.choose(&mut rng).unwrap(); + random_string(&mut rng, len) + }) + }) + .collect(); + + let schema = Schema::new(vec![Field::new("a", values.data_type().clone(), true)]); + let exprs: Vec<_> = haystack.iter().map(|v| lit(to_scalar(v.clone()))).collect(); + let expr = in_list(col("a", &schema).unwrap(), exprs, &false, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), 
vec![Arc::new(values) as ArrayRef]) + .unwrap(); + + c.bench_with_input(BenchmarkId::new(group, name), &batch, |b, batch| { + b.iter(|| expr.evaluate(batch).unwrap()) + }); +} - let in_list: Vec<_> = (0..in_list_length) - .map(|_| make_scalar(gen_value(rng))) - .collect(); +// ============================================================================= +// BITMAP FILTER BENCHMARKS (UInt8, Int16) +// ============================================================================= + +fn bench_bitmap(c: &mut Criterion) { + // UInt8: 32-byte stack-allocated bitmap + // NOTE: With 256 possible values, list_size=16 covers 6.25% of value space, + // so even "match=0%" has ~6% accidental matches from random data. + for list_size in [4, 16] { + for match_pct in MATCH_RATES { + bench_numeric::( + c, + "bitmap", + &format!("u8/list={list_size}/match={match_pct}%"), + &NumericBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + |rng| rng.random(), + |v| ScalarValue::UInt8(Some(v)), + ), + ); + } + } - do_bench( + // Int16: 8KB heap-allocated bitmap (via zero-copy reinterpret) + for list_size in [4, 64, 256] { + for match_pct in MATCH_RATES { + bench_numeric::( c, - &bench_name::(in_list_length, null_percent), - Arc::new(values), - &in_list, + "bitmap", + &format!("i16/list={list_size}/match={match_pct}%"), + &NumericBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + |rng| rng.random(), + |v| ScalarValue::Int16(Some(v)), + ), ); } } } -/// Generates a random string with a length chosen from MIXED_STRING_LENGTHS. 
-fn random_mixed_length_string(rng: &mut StdRng) -> String { - let len = *MIXED_STRING_LENGTHS.choose(rng).unwrap(); - random_string(rng, len) +// ============================================================================= +// PRIMITIVE BENCHMARKS (Branchless vs Hash) +// ============================================================================= + +fn bench_primitive(c: &mut Criterion) { + // Int32: branchless threshold is 32 + for list_size in [4, 32, 64, 256] { + let strategy = if list_size <= 32 { + "branchless" + } else { + "hash" + }; + for match_pct in MATCH_RATES { + bench_numeric::( + c, + "primitive", + &format!("i32/{strategy}/list={list_size}/match={match_pct}%"), + &NumericBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ), + ); + } + } + + // Int64: branchless threshold is 16 + for list_size in [4, 16, 32, 128] { + let strategy = if list_size <= 16 { + "branchless" + } else { + "hash" + }; + for match_pct in MATCH_RATES { + bench_numeric::( + c, + "primitive", + &format!("i64/{strategy}/list={list_size}/match={match_pct}%"), + &NumericBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + |rng| rng.random(), + |v| ScalarValue::Int64(Some(v)), + ), + ); + } + } + + // NOT IN benchmark: test negated path + bench_numeric::( + c, + "primitive", + "i32/branchless/list=16/match=50%/NOT_IN", + &NumericBenchConfig::new( + 16, + 0.5, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ) + .with_negated(), + ); } -/// Benchmarks realistic mixed-length IN list scenario. 
-/// -/// Tests with: -/// - Mixed short (≤12 bytes) and long (>12 bytes) strings in the IN list -/// - Varying prefixes (fully random strings) -/// - Configurable match rate (% of values that are in the IN list) -/// - Various IN list sizes (3, 8, 28, 100) -fn bench_realistic_mixed_strings( - c: &mut Criterion, - rng: &mut StdRng, - make_scalar: fn(String) -> ScalarValue, -) where - A: Array + FromIterator> + 'static, -{ - for in_list_length in IN_LIST_LENGTHS { - for match_percent in [0.0, 0.25, 0.75] { - for null_percent in NULL_PERCENTS { - // Generate IN list with mixed-length random strings - let in_list_strings: Vec = (0..in_list_length) - .map(|_| random_mixed_length_string(rng)) - .collect(); - - let in_list: Vec<_> = in_list_strings - .iter() - .map(|s| make_scalar(s.clone())) - .collect(); - - // Generate values array with controlled match rate - let values: A = (0..ARRAY_LENGTH) - .map(|_| { - if !rng.random_bool(1.0 - null_percent) { - None - } else if rng.random_bool(match_percent) { - // Pick from IN list (will match) - Some(in_list_strings.choose(rng).unwrap().clone()) - } else { - // Generate new random string (unlikely to match) - Some(random_mixed_length_string(rng)) - } - }) - .collect(); - - do_bench( - c, - &format!( - "in_list/{}/mixed/list={}/match={}%/nulls={}%", - array_type_name::(), - in_list_length, - (match_percent * 100.0) as u32, - (null_percent * 100.0) as u32 - ), - Arc::new(values), - &in_list, - ); - } +// ============================================================================= +// REINTERPRETED TYPE BENCHMARKS (Float32, TimestampNs) +// ============================================================================= + +fn bench_reinterpret(c: &mut Criterion) { + // Float32: reinterpreted as u32, uses same branchless/hash strategies + // Threshold is 32 (same as Int32) + for list_size in [4, 32, 64] { + let strategy = if list_size <= 32 { + "branchless" + } else { + "hash" + }; + for match_pct in MATCH_RATES { + 
bench_numeric::( + c, + "reinterpret", + &format!("f32/{strategy}/list={list_size}/match={match_pct}%"), + &NumericBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + |rng| rng.random::() * 1000.0, + |v| ScalarValue::Float32(Some(v)), + ), + ); + } + } + + // TimestampNanosecond: reinterpreted as i64, threshold is 16 + for list_size in [4, 16, 32] { + let strategy = if list_size <= 16 { + "branchless" + } else { + "hash" + }; + for match_pct in MATCH_RATES { + bench_numeric::( + c, + "reinterpret", + &format!("timestamp_ns/{strategy}/list={list_size}/match={match_pct}%"), + &NumericBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + |rng| rng.random::().abs(), + |v| ScalarValue::TimestampNanosecond(Some(v), None), + ), + ); + } + } +} + +// ============================================================================= +// UTF8 TWO-STAGE FILTER BENCHMARKS +// ============================================================================= + +fn bench_utf8(c: &mut Criterion) { + let to_scalar: fn(String) -> ScalarValue = |s| ScalarValue::Utf8(Some(s)); + + // Short strings (8 bytes < 12): Stage 1 definitive + for list_size in [4, 64, 256] { + for match_pct in MATCH_RATES { + bench_string::( + c, + "utf8", + &format!("short_8b/list={list_size}/match={match_pct}%"), + &StringBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + 8, + to_scalar, + ), + ); + } + } + + // Long strings (24 bytes > 12): hits Stage 2 + for list_size in [4, 64, 256] { + for match_pct in MATCH_RATES { + bench_string::( + c, + "utf8", + &format!("long_24b/list={list_size}/match={match_pct}%"), + &StringBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + 24, + to_scalar, + ), + ); + } + } + + // Mixed-length strings: realistic distribution + for list_size in [16, 64] { + for match_pct in MATCH_RATES { + bench_string_mixed_lengths::( + c, + "utf8", + &format!("mixed_len/list={list_size}/match={match_pct}%"), + list_size, + match_pct as f64 / 100.0, + to_scalar, + ); + 
} + } + + // Prefix collision: stresses Stage 2 comparison + bench_string_prefix_collision::( + c, + "utf8", + "prefix_collision/pfx=12/list=32/match=50%", + 32, + 0.5, + 12, + to_scalar, + ); + + // NOT IN benchmark + bench_string::( + c, + "utf8", + "short_8b/list=16/match=50%/NOT_IN", + &StringBenchConfig::new(16, 0.5, 8, to_scalar).with_negated(), + ); +} + +// ============================================================================= +// UTF8VIEW TWO-STAGE FILTER BENCHMARKS +// ============================================================================= + +fn bench_utf8view(c: &mut Criterion) { + let to_scalar: fn(String) -> ScalarValue = |s| ScalarValue::Utf8View(Some(s)); + + // Short strings (8 bytes ≤ 12): inline storage path + for list_size in [4, 16, 64, 256] { + for match_pct in MATCH_RATES { + bench_string::( + c, + "utf8view", + &format!("short_8b/list={list_size}/match={match_pct}%"), + &StringBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + 8, + to_scalar, + ), + ); + } + } + + // Boundary strings (exactly 12 bytes): max inline size + for list_size in [16, 64] { + for match_pct in MATCH_RATES { + bench_string::( + c, + "utf8view", + &format!("boundary_12b/list={list_size}/match={match_pct}%"), + &StringBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + 12, + to_scalar, + ), + ); + } + } + + // Long strings (24 bytes > 12): out-of-line storage, two-stage filter + for list_size in [4, 16, 64, 256] { + for match_pct in MATCH_RATES { + bench_string::( + c, + "utf8view", + &format!("long_24b/list={list_size}/match={match_pct}%"), + &StringBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + 24, + to_scalar, + ), + ); + } + } + + // Mixed-length strings: realistic distribution + for list_size in [16, 64] { + for match_pct in MATCH_RATES { + bench_string_mixed_lengths::( + c, + "utf8view", + &format!("mixed_len/list={list_size}/match={match_pct}%"), + list_size, + match_pct as f64 / 100.0, + to_scalar, + ); + } + } + + // 
Prefix collision: stresses Stage 2 comparison with varying prefix lengths + for (prefix_len, list_size) in [(8, 16), (12, 32), (16, 64)] { + for match_pct in MATCH_RATES { + bench_string_prefix_collision::( + c, + "utf8view", + &format!( + "prefix_collision/pfx={prefix_len}/list={list_size}/match={match_pct}%" + ), + list_size, + match_pct as f64 / 100.0, + prefix_len, + to_scalar, + ); } } } -/// Entry point: registers in_list benchmarks for string and numeric array types. -fn criterion_benchmark(c: &mut Criterion) { - let mut rng = StdRng::seed_from_u64(120320); +// ============================================================================= +// DICTIONARY ARRAY BENCHMARKS +// ============================================================================= + +/// Helper to benchmark dictionary-encoded Int32 arrays +fn bench_dict_int32( + c: &mut Criterion, + name: &str, + dict_size: usize, + list_size: usize, + negated: bool, +) { + let seed = 0xD1C7_0000_u64 + .wrapping_add(dict_size as u64 * 0x1111) + .wrapping_add(list_size as u64 * 0x2222); + let mut rng = StdRng::seed_from_u64(seed); + + let dict_values: Vec = (0..dict_size).map(|_| rng.random()).collect(); + let haystack: Vec = dict_values.iter().take(list_size).cloned().collect(); + + let indices: Vec = (0..ARRAY_SIZE) + .map(|_| rng.random_range(0..dict_size as i32)) + .collect(); + let indices_array = Int32Array::from(indices); + let values_array = Int32Array::from(dict_values); + let dict_array = + DictionaryArray::::try_new(indices_array, Arc::new(values_array)) + .unwrap(); - // Benchmarks for string array types (Utf8, Utf8View) - bench_string_type::(c, &mut rng, |s| ScalarValue::Utf8(Some(s))); - bench_string_type::(c, &mut rng, |s| ScalarValue::Utf8View(Some(s))); + let schema = Schema::new(vec![Field::new("a", dict_array.data_type().clone(), true)]); + let exprs: Vec<_> = haystack + .iter() + .map(|v| lit(ScalarValue::Int32(Some(*v)))) + .collect(); + let expr = in_list(col("a", &schema).unwrap(), 
exprs, &negated, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict_array) as ArrayRef]) + .unwrap(); - // Realistic mixed-length string benchmarks (TPC-H style) - bench_realistic_mixed_strings::(c, &mut rng, |s| { - ScalarValue::Utf8(Some(s)) + c.bench_with_input(BenchmarkId::new("dictionary", name), &batch, |b, batch| { + b.iter(|| expr.evaluate(batch).unwrap()) }); - bench_realistic_mixed_strings::(c, &mut rng, |s| { - ScalarValue::Utf8View(Some(s)) +} + +/// Helper to benchmark dictionary-encoded string arrays +fn bench_dict_string( + c: &mut Criterion, + name: &str, + dict_size: usize, + list_size: usize, + string_len: usize, +) { + let seed = 0xD1C7_5778_u64 + .wrapping_add(dict_size as u64 * 0x3333) + .wrapping_add(string_len as u64 * 0x4444); + let mut rng = StdRng::seed_from_u64(seed); + + let dict_values: Vec = (0..dict_size) + .map(|_| random_string(&mut rng, string_len)) + .collect(); + let haystack: Vec = dict_values.iter().take(list_size).cloned().collect(); + + let indices: Vec = (0..ARRAY_SIZE) + .map(|_| rng.random_range(0..dict_size as i32)) + .collect(); + let indices_array = Int32Array::from(indices); + let values_array = StringArray::from(dict_values); + let dict_array = + DictionaryArray::::try_new(indices_array, Arc::new(values_array)) + .unwrap(); + + let schema = Schema::new(vec![Field::new("a", dict_array.data_type().clone(), true)]); + let exprs: Vec<_> = haystack + .iter() + .map(|v| lit(ScalarValue::Utf8(Some(v.clone())))) + .collect(); + let expr = in_list(col("a", &schema).unwrap(), exprs, &false, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict_array) as ArrayRef]) + .unwrap(); + + c.bench_with_input(BenchmarkId::new("dictionary", name), &batch, |b, batch| { + b.iter(|| expr.evaluate(batch).unwrap()) }); +} + +fn bench_dictionary(c: &mut Criterion) { + // Int32 dictionary: varying list sizes (tests branchless vs hash on values) + // Dictionary with 
100 unique values + for list_size in [4, 16, 64] { + bench_dict_int32( + c, + &format!("i32/dict=100/list={list_size}"), + 100, + list_size, + false, + ); + } + + // Int32 dictionary: varying dictionary cardinality + for dict_size in [10, 1000] { + bench_dict_int32( + c, + &format!("i32/dict={dict_size}/list=16"), + dict_size, + 16, + false, + ); + } - // Benchmarks for numeric types - bench_numeric_type::( + // Int32 dictionary: NOT IN path + bench_dict_int32(c, "i32/dict=100/list=16/NOT_IN", 100, 16, true); + + // String dictionary: short strings (≤12 bytes, common for codes/categories) + for list_size in [8, 32] { + bench_dict_string( + c, + &format!("utf8_short/dict=50/list={list_size}"), + 50, + list_size, + 8, + ); + } + + // String dictionary: long strings (>12 bytes) + bench_dict_string(c, "utf8_long/dict=100/list=16", 100, 16, 24); + + // String dictionary: large cardinality (realistic category counts) + bench_dict_string(c, "utf8_short/dict=500/list=20", 500, 20, 10); +} + +// ============================================================================= +// NULL HANDLING BENCHMARKS +// ============================================================================= +// +// Tests null short-circuit optimization paths in: +// - build_in_list_result: computes contains for ALL positions, masks via bitmap ops +// - build_in_list_result_with_null_shortcircuit: skips contains for null positions +// +// The shortcircuit is beneficial for expensive contains checks (strings) but +// adds branch overhead for cheap checks (primitives). 
+ +fn bench_nulls(c: &mut Criterion) { + // ========================================================================= + // PRIMITIVE TYPES: Tests build_in_list_result (no shortcircuit) + // ========================================================================= + + // BitmapFilter with nulls + bench_numeric::( c, - &mut rng, - |rng| rng.random(), - |v| ScalarValue::UInt8(Some(v)), + "nulls", + "bitmap/u8/list=16/match=50%/nulls=20%", + &NumericBenchConfig::new( + 16, + 0.5, + |rng| rng.random(), + |v| ScalarValue::UInt8(Some(v)), + ) + .with_null_rate(0.2), ); - bench_numeric_type::( + + // BranchlessFilter with nulls + bench_numeric::( c, - &mut rng, - |rng| rng.random(), - |v| ScalarValue::Int16(Some(v)), + "nulls", + "branchless/i32/list=16/match=50%/nulls=20%", + &NumericBenchConfig::new( + 16, + 0.5, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ) + .with_null_rate(0.2), ); - bench_numeric_type::( + + // DirectProbeFilter with nulls + bench_numeric::( c, - &mut rng, - |rng| rng.random(), - |v| ScalarValue::Float32(Some(v)), + "nulls", + "hash/i32/list=64/match=50%/nulls=20%", + &NumericBenchConfig::new( + 64, + 0.5, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ) + .with_null_rate(0.2), ); - bench_numeric_type::( + + // ========================================================================= + // STRING TYPES: Tests build_in_list_result_with_null_shortcircuit + // ========================================================================= + + let utf8_scalar: fn(String) -> ScalarValue = |s| ScalarValue::Utf8(Some(s)); + let utf8view_scalar: fn(String) -> ScalarValue = |s| ScalarValue::Utf8View(Some(s)); + + // Utf8TwoStageFilter with nulls (short strings) + bench_string::( c, - &mut rng, - |rng| rng.random(), - |v| ScalarValue::Int32(Some(v)), + "nulls", + "utf8/short_8b/list=16/match=50%/nulls=20%", + &StringBenchConfig::new(16, 0.5, 8, utf8_scalar).with_null_rate(0.2), ); - bench_numeric_type::( + + // Utf8TwoStageFilter with nulls 
(long strings - Stage 2) + bench_string::( + c, + "nulls", + "utf8/long_24b/list=16/match=50%/nulls=20%", + &StringBenchConfig::new(16, 0.5, 24, utf8_scalar).with_null_rate(0.2), + ); + + // ByteViewMaskedFilter with nulls (short strings - inline) + bench_string::( + c, + "nulls", + "utf8view/short_8b/list=16/match=50%/nulls=20%", + &StringBenchConfig::new(16, 0.5, 8, utf8view_scalar).with_null_rate(0.2), + ); + + // ByteViewMaskedFilter with nulls (long strings - out-of-line) + bench_string::( + c, + "nulls", + "utf8view/long_24b/list=16/match=50%/nulls=20%", + &StringBenchConfig::new(16, 0.5, 24, utf8view_scalar).with_null_rate(0.2), + ); + + // ========================================================================= + // NOT IN WITH NULLS: Tests negated path with null propagation + // ========================================================================= + + // Primitive NOT IN with nulls + bench_numeric::( c, - &mut rng, - |rng| rng.random(), - |v| ScalarValue::TimestampNanosecond(Some(v), None), + "nulls", + "branchless/i32/list=16/match=50%/nulls=20%/NOT_IN", + &NumericBenchConfig::new( + 16, + 0.5, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ) + .with_null_rate(0.2) + .with_negated(), + ); + + // String NOT IN with nulls + bench_string::( + c, + "nulls", + "utf8view/short_8b/list=16/match=50%/nulls=20%/NOT_IN", + &StringBenchConfig::new(16, 0.5, 8, utf8view_scalar) + .with_null_rate(0.2) + .with_negated(), + ); + + // ========================================================================= + // HIGH NULL RATE: Stress test null handling paths + // ========================================================================= + + // 50% nulls - half the array is null + bench_numeric::( + c, + "nulls", + "branchless/i32/list=16/match=50%/nulls=50%", + &NumericBenchConfig::new( + 16, + 0.5, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ) + .with_null_rate(0.5), + ); + + bench_string::( + c, + "nulls", + 
"utf8view/short_8b/list=16/match=50%/nulls=50%", + &StringBenchConfig::new(16, 0.5, 8, utf8view_scalar).with_null_rate(0.5), ); } +// ============================================================================= +// CRITERION SETUP +// ============================================================================= + criterion_group! { name = benches; - config = Criterion::default() - .warm_up_time(Duration::from_millis(100)) - .measurement_time(Duration::from_millis(500)); - targets = criterion_benchmark + config = Criterion::default(); + targets = bench_bitmap, bench_primitive, bench_reinterpret, bench_utf8, bench_utf8view, bench_dictionary, bench_nulls } + criterion_main!(benches); diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 5c2f1adcd0cf3..6250547cb8cae 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -17,6 +17,13 @@ //! Implementation of `InList` expressions: [`InListExpr`] +mod nested_filter; +mod primitive_filter; +mod result; +mod static_filter; +mod strategy; +mod transform; + use std::any::Any; use std::fmt::Debug; use std::hash::{Hash, Hasher}; @@ -27,28 +34,16 @@ use crate::physical_expr::physical_exprs_bag_equal; use arrow::array::*; use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::SortOptions; use arrow::compute::kernels::boolean::{not, or_kleene}; -use arrow::compute::{SortOptions, take}; use arrow::datatypes::*; -use arrow::util::bit_iterator::BitIndexIterator; -use datafusion_common::hash_utils::with_hashes; use datafusion_common::{ - DFSchema, HashSet, Result, ScalarValue, assert_or_internal_err, exec_datafusion_err, - exec_err, + DFSchema, Result, ScalarValue, assert_or_internal_err, exec_err, }; use datafusion_expr::{ColumnarValue, expr_vec_fmt}; -use ahash::RandomState; -use datafusion_common::HashMap; -use hashbrown::hash_map::RawEntryMut; - -/// Trait for InList static 
filters -trait StaticFilter { - fn null_count(&self) -> usize; - - /// Checks if values in `v` are contained in the filter - fn contains(&self, v: &dyn Array, negated: bool) -> Result; -} +use static_filter::StaticFilter; +use strategy::instantiate_static_filter; /// InList pub struct InListExpr { @@ -68,470 +63,6 @@ impl Debug for InListExpr { } } -/// Static filter for InList that stores the array and hash set for O(1) lookups -#[derive(Debug, Clone)] -struct ArrayStaticFilter { - in_array: ArrayRef, - state: RandomState, - /// Used to provide a lookup from value to in list index - /// - /// Note: usize::hash is not used, instead the raw entry - /// API is used to store entries w.r.t their value - map: HashMap, -} - -impl StaticFilter for ArrayStaticFilter { - fn null_count(&self) -> usize { - self.in_array.null_count() - } - - /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Null type comparisons always return null (SQL three-valued logic) - if v.data_type() == &DataType::Null - || self.in_array.data_type() == &DataType::Null - { - let nulls = NullBuffer::new_null(v.len()); - return Ok(BooleanArray::new( - BooleanBuffer::new_unset(v.len()), - Some(nulls), - )); - } - - downcast_dictionary_array! { - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } - - let needle_nulls = v.logical_nulls(); - let needle_nulls = needle_nulls.as_ref(); - let haystack_has_nulls = self.in_array.null_count() != 0; - - with_hashes([v], &self.state, |hashes| { - let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; - Ok((0..v.len()) - .map(|i| { - // SQL three-valued logic: null IN (...) 
is always null - if needle_nulls.is_some_and(|nulls| nulls.is_null(i)) { - return None; - } - - let hash = hashes[i]; - let contains = self - .map - .raw_entry() - .from_hash(hash, |idx| cmp(i, *idx).is_eq()) - .is_some(); - - match contains { - true => Some(!negated), - false if haystack_has_nulls => None, - false => Some(negated), - } - }) - .collect()) - }) - } -} - -fn instantiate_static_filter( - in_array: ArrayRef, -) -> Result> { - match in_array.data_type() { - // Integer primitive types - DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), - DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), - DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), - DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), - DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)), - DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), - DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), - DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), - // Float primitive types (use ordered wrappers for Hash/Eq) - DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)), - DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)), - _ => { - /* fall through to generic implementation for unsupported types (Struct, etc.) */ - Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) - } - } -} - -impl ArrayStaticFilter { - /// Computes a [`StaticFilter`] for the provided [`Array`] if there - /// are nulls present or there are more than the configured number of - /// elements. 
- /// - /// Note: This is split into a separate function as higher-rank trait bounds currently - /// cause type inference to misbehave - fn try_new(in_array: ArrayRef) -> Result { - // Null type has no natural order - return empty hash set - if in_array.data_type() == &DataType::Null { - return Ok(ArrayStaticFilter { - in_array, - state: RandomState::new(), - map: HashMap::with_hasher(()), - }); - } - - let state = RandomState::new(); - let mut map: HashMap = HashMap::with_hasher(()); - - with_hashes([&in_array], &state, |hashes| -> Result<()> { - let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?; - - let insert_value = |idx| { - let hash = hashes[idx]; - if let RawEntryMut::Vacant(v) = map - .raw_entry_mut() - .from_hash(hash, |x| cmp(*x, idx).is_eq()) - { - v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); - } - }; - - match in_array.nulls() { - Some(nulls) => { - BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) - .for_each(insert_value) - } - None => (0..in_array.len()).for_each(insert_value), - } - - Ok(()) - })?; - - Ok(Self { - in_array, - state, - map, - }) - } -} - -/// Wrapper for f32 that implements Hash and Eq using bit comparison. -/// This treats NaN values as equal to each other when they have the same bit pattern. -#[derive(Clone, Copy)] -struct OrderedFloat32(f32); - -impl Hash for OrderedFloat32 { - fn hash(&self, state: &mut H) { - self.0.to_ne_bytes().hash(state); - } -} - -impl PartialEq for OrderedFloat32 { - fn eq(&self, other: &Self) -> bool { - self.0.to_bits() == other.0.to_bits() - } -} - -impl Eq for OrderedFloat32 {} - -impl From for OrderedFloat32 { - fn from(v: f32) -> Self { - Self(v) - } -} - -/// Wrapper for f64 that implements Hash and Eq using bit comparison. -/// This treats NaN values as equal to each other when they have the same bit pattern. 
-#[derive(Clone, Copy)] -struct OrderedFloat64(f64); - -impl Hash for OrderedFloat64 { - fn hash(&self, state: &mut H) { - self.0.to_ne_bytes().hash(state); - } -} - -impl PartialEq for OrderedFloat64 { - fn eq(&self, other: &Self) -> bool { - self.0.to_bits() == other.0.to_bits() - } -} - -impl Eq for OrderedFloat64 {} - -impl From for OrderedFloat64 { - fn from(v: f64) -> Self { - Self(v) - } -} - -// Macro to generate specialized StaticFilter implementations for primitive types -macro_rules! primitive_static_filter { - ($Name:ident, $ArrowType:ty) => { - struct $Name { - null_count: usize, - values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>, - } - - impl $Name { - fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let mut values = HashSet::with_capacity(in_array.len()); - let null_count = in_array.null_count(); - - for v in in_array.iter().flatten() { - values.insert(v); - } - - Ok(Self { null_count, values }) - } - } - - impl StaticFilter for $Name { - fn null_count(&self) -> usize { - self.null_count - } - - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! 
{ - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } - - let v = v - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let haystack_has_nulls = self.null_count > 0; - - let needle_values = v.values(); - let needle_nulls = v.nulls(); - let needle_has_nulls = v.null_count() > 0; - - // Truth table for `value [NOT] IN (set)` with SQL three-valued logic: - // ("-" means the value doesn't affect the result) - // - // | needle_null | haystack_null | negated | in set? | result | - // |-------------|---------------|---------|---------|--------| - // | true | - | false | - | null | - // | true | - | true | - | null | - // | false | true | false | yes | true | - // | false | true | false | no | null | - // | false | true | true | yes | false | - // | false | true | true | no | null | - // | false | false | false | yes | true | - // | false | false | false | no | false | - // | false | false | true | yes | false | - // | false | false | true | no | true | - - // Compute the "contains" result using collect_bool (fast batched approach) - // This ignores nulls - we handle them separately - let contains_buffer = if negated { - BooleanBuffer::collect_bool(needle_values.len(), |i| { - !self.values.contains(&needle_values[i]) - }) - } else { - BooleanBuffer::collect_bool(needle_values.len(), |i| { - self.values.contains(&needle_values[i]) - }) - }; - - // Compute the null mask - // Output is null when: - // 1. needle value is null, OR - // 2. 
needle value is not in set AND haystack has nulls - let result_nulls = match (needle_has_nulls, haystack_has_nulls) { - (false, false) => { - // No nulls anywhere - None - } - (true, false) => { - // Only needle has nulls - just use needle's null mask - needle_nulls.cloned() - } - (false, true) => { - // Only haystack has nulls - result is null when value not in set - // Valid (not null) when original "in set" is true - // For NOT IN: contains_buffer = !original, so validity = !contains_buffer - let validity = if negated { - !&contains_buffer - } else { - contains_buffer.clone() - }; - Some(NullBuffer::new(validity)) - } - (true, true) => { - // Both have nulls - combine needle nulls with haystack-induced nulls - let needle_validity = needle_nulls.map(|n| n.inner().clone()) - .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); - - // Valid when original "in set" is true (see above) - let haystack_validity = if negated { - !&contains_buffer - } else { - contains_buffer.clone() - }; - - // Combined validity: valid only where both are valid - let combined_validity = &needle_validity & &haystack_validity; - Some(NullBuffer::new(combined_validity)) - } - }; - - Ok(BooleanArray::new(contains_buffer, result_nulls)) - } - } - }; -} - -// Generate specialized filters for all integer primitive types -primitive_static_filter!(Int8StaticFilter, Int8Type); -primitive_static_filter!(Int16StaticFilter, Int16Type); -primitive_static_filter!(Int32StaticFilter, Int32Type); -primitive_static_filter!(Int64StaticFilter, Int64Type); -primitive_static_filter!(UInt8StaticFilter, UInt8Type); -primitive_static_filter!(UInt16StaticFilter, UInt16Type); -primitive_static_filter!(UInt32StaticFilter, UInt32Type); -primitive_static_filter!(UInt64StaticFilter, UInt64Type); - -// Macro to generate specialized StaticFilter implementations for float types -// Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics -macro_rules! 
float_static_filter { - ($Name:ident, $ArrowType:ty, $OrderedType:ty) => { - struct $Name { - null_count: usize, - values: HashSet<$OrderedType>, - } - - impl $Name { - fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let mut values = HashSet::with_capacity(in_array.len()); - let null_count = in_array.null_count(); - - for v in in_array.iter().flatten() { - values.insert(<$OrderedType>::from(v)); - } - - Ok(Self { null_count, values }) - } - } - - impl StaticFilter for $Name { - fn null_count(&self) -> usize { - self.null_count - } - - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! { - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } - - let v = v - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let haystack_has_nulls = self.null_count > 0; - - let needle_values = v.values(); - let needle_nulls = v.nulls(); - let needle_has_nulls = v.null_count() > 0; - - // Truth table for `value [NOT] IN (set)` with SQL three-valued logic: - // ("-" means the value doesn't affect the result) - // - // | needle_null | haystack_null | negated | in set? 
| result | - // |-------------|---------------|---------|---------|--------| - // | true | - | false | - | null | - // | true | - | true | - | null | - // | false | true | false | yes | true | - // | false | true | false | no | null | - // | false | true | true | yes | false | - // | false | true | true | no | null | - // | false | false | false | yes | true | - // | false | false | false | no | false | - // | false | false | true | yes | false | - // | false | false | true | no | true | - - // Compute the "contains" result using collect_bool (fast batched approach) - // This ignores nulls - we handle them separately - let contains_buffer = if negated { - BooleanBuffer::collect_bool(needle_values.len(), |i| { - !self.values.contains(&<$OrderedType>::from(needle_values[i])) - }) - } else { - BooleanBuffer::collect_bool(needle_values.len(), |i| { - self.values.contains(&<$OrderedType>::from(needle_values[i])) - }) - }; - - // Compute the null mask - // Output is null when: - // 1. needle value is null, OR - // 2. 
needle value is not in set AND haystack has nulls - let result_nulls = match (needle_has_nulls, haystack_has_nulls) { - (false, false) => { - // No nulls anywhere - None - } - (true, false) => { - // Only needle has nulls - just use needle's null mask - needle_nulls.cloned() - } - (false, true) => { - // Only haystack has nulls - result is null when value not in set - // Valid (not null) when original "in set" is true - // For NOT IN: contains_buffer = !original, so validity = !contains_buffer - let validity = if negated { - !&contains_buffer - } else { - contains_buffer.clone() - }; - Some(NullBuffer::new(validity)) - } - (true, true) => { - // Both have nulls - combine needle nulls with haystack-induced nulls - let needle_validity = needle_nulls.map(|n| n.inner().clone()) - .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); - - // Valid when original "in set" is true (see above) - let haystack_validity = if negated { - !&contains_buffer - } else { - contains_buffer.clone() - }; - - // Combined validity: valid only where both are valid - let combined_validity = &needle_validity & &haystack_validity; - Some(NullBuffer::new(combined_validity)) - } - }; - - Ok(BooleanArray::new(contains_buffer, result_nulls)) - } - } - }; -} - -// Generate specialized filters for float types using ordered wrappers -float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32); -float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64); - /// Evaluates the list of expressions into an array, flattening any dictionaries fn evaluate_list( list: &[Arc], diff --git a/datafusion/physical-expr/src/expressions/in_list/nested_filter.rs b/datafusion/physical-expr/src/expressions/in_list/nested_filter.rs new file mode 100644 index 0000000000000..091728eb0397c --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/nested_filter.rs @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license 
agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fallback filter for nested/complex types (List, Struct, Map, Union, etc.) + +use arrow::array::{ + Array, ArrayRef, BooleanArray, downcast_array, downcast_dictionary_array, + make_comparator, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::{SortOptions, take}; +use arrow::datatypes::DataType; +use arrow::util::bit_iterator::BitIndexIterator; +use datafusion_common::Result; +use datafusion_common::hash_utils::with_hashes; + +use ahash::RandomState; +use hashbrown::HashTable; + +use super::result::build_in_list_result; +use super::static_filter::StaticFilter; + +/// Fallback filter for nested/complex types (List, Struct, Map, Union, etc.) +/// +/// Uses dynamic comparator via `make_comparator` since these types don't have +/// a simple typed comparison. For primitive and byte array types, use the +/// specialized filters instead (PrimitiveFilter, ByteArrayFilter, etc.) +#[derive(Debug, Clone)] +pub(crate) struct NestedTypeFilter { + in_array: ArrayRef, + state: RandomState, + /// Stores indices into `in_array` for O(1) lookups. + table: HashTable, +} + +impl NestedTypeFilter { + /// Creates a filter for nested/complex array types. 
+ /// + /// This filter uses dynamic comparison and should only be used for types + /// that don't have specialized filters (List, Struct, Map, Union). + pub(crate) fn try_new(in_array: ArrayRef) -> Result { + // Null type has no natural order - return empty hash set + if in_array.data_type() == &DataType::Null { + return Ok(Self { + in_array, + state: RandomState::new(), + table: HashTable::new(), + }); + } + + let state = RandomState::new(); + let table = Self::build_haystack_table(&in_array, &state)?; + + Ok(Self { + in_array, + state, + table, + }) + } + + /// Build a hash table from haystack values for O(1) lookups. + /// + /// Each unique non-null value's index is stored, keyed by its hash. + /// Uses dynamic comparison via `make_comparator` for complex types. + fn build_haystack_table( + haystack: &ArrayRef, + state: &RandomState, + ) -> Result> { + let mut table = HashTable::new(); + + with_hashes([haystack.as_ref()], state, |hashes| -> Result<()> { + let cmp = make_comparator(haystack, haystack, SortOptions::default())?; + + let insert_value = |idx| { + let hash = hashes[idx]; + // Only insert if not already present (deduplication) + if table.find(hash, |&x| cmp(x, idx).is_eq()).is_none() { + table.insert_unique(hash, idx, |&x| hashes[x]); + } + }; + + match haystack.nulls() { + Some(nulls) => { + BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) + .for_each(insert_value) + } + None => (0..haystack.len()).for_each(insert_value), + } + + Ok(()) + })?; + + Ok(table) + } + + /// Check which needle values exist in the haystack. + /// + /// Hashes each needle value and looks it up in the pre-built haystack table. + /// Uses dynamic comparison via `make_comparator` for complex types. 
+ fn find_needles_in_haystack( + &self, + needles: &dyn Array, + negated: bool, + ) -> Result { + let needle_nulls = needles.logical_nulls(); + let haystack_has_nulls = self.in_array.null_count() != 0; + + with_hashes([needles], &self.state, |needle_hashes| { + let cmp = make_comparator(needles, &self.in_array, SortOptions::default())?; + + Ok(build_in_list_result( + needles.len(), + needle_nulls.as_ref(), + haystack_has_nulls, + negated, + #[inline(always)] + |i| { + let hash = needle_hashes[i]; + self.table.find(hash, |&idx| cmp(i, idx).is_eq()).is_some() + }, + )) + }) + } +} + +impl StaticFilter for NestedTypeFilter { + fn null_count(&self) -> usize { + self.in_array.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Null type comparisons always return null (SQL three-valued logic) + if v.data_type() == &DataType::Null + || self.in_array.data_type() == &DataType::Null + { + let nulls = NullBuffer::new_null(v.len()); + return Ok(BooleanArray::new( + BooleanBuffer::new_unset(v.len()), + Some(nulls), + )); + } + + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + + self.find_needles_in_haystack(v, negated) + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs new file mode 100644 index 0000000000000..efe74a4409dd1 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs @@ -0,0 +1,471 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimized primitive type filters for InList expressions +//! +//! This module provides high-performance membership testing for Arrow primitive types. + +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; +use arrow::datatypes::ArrowPrimitiveType; +use datafusion_common::{HashSet, Result, exec_datafusion_err}; + +use super::result::{build_in_list_result, handle_dictionary}; +use super::static_filter::StaticFilter; + +// ============================================================================= +// BITMAP FILTERS (O(1) lookup for u8/u16 via bit test) +// ============================================================================= + +/// Trait for bitmap storage (stack-allocated for u8, heap-allocated for u16). 
+pub(crate) trait BitmapStorage: Send + Sync { + fn new_zeroed() -> Self; + fn set_bit(&mut self, index: usize); + fn get_bit(&self, index: usize) -> bool; +} + +impl BitmapStorage for [u64; 4] { + #[inline] + fn new_zeroed() -> Self { + [0u64; 4] + } + #[inline] + fn set_bit(&mut self, index: usize) { + self[index / 64] |= 1u64 << (index % 64); + } + #[inline(always)] + fn get_bit(&self, index: usize) -> bool { + (self[index / 64] >> (index % 64)) & 1 != 0 + } +} + +impl BitmapStorage for Box<[u64; 1024]> { + #[inline] + fn new_zeroed() -> Self { + Box::new([0u64; 1024]) + } + #[inline] + fn set_bit(&mut self, index: usize) { + self[index / 64] |= 1u64 << (index % 64); + } + #[inline(always)] + fn get_bit(&self, index: usize) -> bool { + (self[index / 64] >> (index % 64)) & 1 != 0 + } +} + +/// Configuration trait for bitmap filters. +pub(crate) trait BitmapFilterConfig: Send + Sync + 'static { + type Native: arrow::datatypes::ArrowNativeType + Copy + Send + Sync; + type ArrowType: ArrowPrimitiveType; + type Storage: BitmapStorage; + + fn to_index(v: Self::Native) -> usize; +} + +/// Config for u8 bitmap (256 bits = 32 bytes, fits in cache line). +pub(crate) enum U8Config {} +impl BitmapFilterConfig for U8Config { + type Native = u8; + type ArrowType = arrow::datatypes::UInt8Type; + type Storage = [u64; 4]; + + #[inline(always)] + fn to_index(v: u8) -> usize { + v as usize + } +} + +/// Config for u16 bitmap (65536 bits = 8 KB, fits in L1 cache). +pub(crate) enum U16Config {} +impl BitmapFilterConfig for U16Config { + type Native = u16; + type ArrowType = arrow::datatypes::UInt16Type; + type Storage = Box<[u64; 1024]>; + + #[inline(always)] + fn to_index(v: u16) -> usize { + v as usize + } +} + +/// Bitmap filter for O(1) set membership via single bit test. +/// +/// For small integer types (u8/u16), bitmap lookup outperforms both branchless +/// and hashed approaches at all list sizes. 
+pub(crate) struct BitmapFilter { + null_count: usize, + bits: C::Storage, +} + +impl BitmapFilter { + pub(crate) fn try_new(in_array: &ArrayRef) -> Result { + let prim_array = + in_array.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!("BitmapFilter: expected primitive array") + })?; + let mut bits = C::Storage::new_zeroed(); + for v in prim_array.iter().flatten() { + bits.set_bit(C::to_index(v)); + } + Ok(Self { + null_count: prim_array.null_count(), + bits, + }) + } + + #[inline(always)] + fn check(&self, needle: C::Native) -> bool { + self.bits.get_bit(C::to_index(needle)) + } + + /// Check membership using a raw values slice (zero-copy path for type reinterpretation). + #[inline] + pub(crate) fn contains_slice( + &self, + values: &[C::Native], + nulls: Option<&arrow::buffer::NullBuffer>, + negated: bool, + ) -> BooleanArray { + build_in_list_result(values.len(), nulls, self.null_count > 0, negated, |i| { + self.check(unsafe { *values.get_unchecked(i) }) + }) + } +} + +impl StaticFilter for BitmapFilter { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + let v = v.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!("BitmapFilter: expected primitive array") + })?; + let input_values = v.values(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + #[inline(always)] + |i| self.check(unsafe { *input_values.get_unchecked(i) }), + )) + } +} + +// ============================================================================= +// BRANCHLESS FILTER (Const Generic for Small Lists) +// ============================================================================= + +/// A branchless filter for very small IN lists (0-16 elements). 
+/// +/// Uses const generics to unroll the membership check into a fixed-size +/// comparison chain, outperforming hash lookups for small lists due to: +/// - No branching (uses bitwise OR to combine comparisons) +/// - Better CPU pipelining +/// - No hash computation overhead +pub(crate) struct BranchlessFilter { + null_count: usize, + values: [T::Native; N], +} + +impl BranchlessFilter +where + T::Native: Copy + PartialEq, +{ + /// Try to create a branchless filter if the array has exactly N non-null values. + pub(crate) fn try_new(in_array: &ArrayRef) -> Option> { + let in_array = in_array.as_primitive_opt::()?; + let non_null_count = in_array.len() - in_array.null_count(); + if non_null_count != N { + return None; + } + let values: Vec<_> = in_array.iter().flatten().collect(); + // Use default_value() from ArrowPrimitiveType trait instead of Default::default() + let mut arr = [T::default_value(); N]; + arr.copy_from_slice(&values); + Some(Ok(Self { + null_count: in_array.null_count(), + values: arr, + })) + } + + /// Branchless membership check using OR-chain. + #[inline(always)] + fn check(&self, needle: T::Native) -> bool { + self.values + .iter() + .fold(false, |acc, &v| acc | (v == needle)) + } + + /// Check membership using a raw values slice (zero-copy path for type reinterpretation). 
+ #[inline] + pub(crate) fn contains_slice( + &self, + values: &[T::Native], + nulls: Option<&arrow::buffer::NullBuffer>, + negated: bool, + ) -> BooleanArray { + build_in_list_result(values.len(), nulls, self.null_count > 0, negated, |i| { + self.check(unsafe { *values.get_unchecked(i) }) + }) + } +} + +impl StaticFilter for BranchlessFilter +where + T::Native: Copy + PartialEq + Send + Sync, +{ + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + let v = v.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!("Failed to downcast array to primitive type") + })?; + let input_values = v.values(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + // SAFETY: i is in bounds since we iterate 0..v.len() + #[inline(always)] + |i| self.check(unsafe { *input_values.get_unchecked(i) }), + )) + } +} + +// ============================================================================= +// DIRECT PROBE HASH FILTER (O(1) lookup with open addressing) +// ============================================================================= + +/// Load factor inverse for DirectProbeFilter hash table. +/// A value of 4 means 25% load factor (table is 4x the number of elements). +const LOAD_FACTOR_INVERSE: usize = 4; + +/// Minimum table size for DirectProbeFilter. +/// Ensures reasonable performance even for very small IN lists. +const MIN_TABLE_SIZE: usize = 16; + +/// Golden ratio constant for 32-bit hash mixing. +/// Derived from (2^32 / phi) where phi = (1 + sqrt(5)) / 2. +const GOLDEN_RATIO_32: u32 = 0x9e3779b9; + +/// Golden ratio constant for 64-bit hash mixing. +/// Derived from (2^64 / phi) where phi = (1 + sqrt(5)) / 2. +const GOLDEN_RATIO_64: u64 = 0x9e3779b97f4a7c15; + +/// Secondary mixing constant for 128-bit hashing (from SplitMix64). +/// Using a different constant for hi/lo avoids collisions when lo = hi * C. 
+const SPLITMIX_CONSTANT: u64 = 0xbf58476d1ce4e5b9; + +/// Fast hash filter using open addressing with linear probing. +/// +/// Uses a power-of-2 sized hash table for O(1) average-case lookups. +/// Optimized for the IN list use case with: +/// - Simple/fast hash function (golden ratio multiply + xor-shift) +/// - 25% load factor for minimal collisions +/// - Direct array storage for cache-friendly access +pub(crate) struct DirectProbeFilter +where + T::Native: DirectProbeHashable, +{ + null_count: usize, + /// Hash table with open addressing. None = empty slot, Some(v) = value present + table: Box<[Option]>, + /// Mask for slot index (table.len() - 1, always power of 2 minus 1) + mask: usize, +} + +/// Trait for types that can be hashed for the direct probe filter. +/// +/// Requires `Hash + Eq` for deduplication via `HashSet`, even though we use +/// a custom `probe_hash()` for the actual hash table lookups. +pub(crate) trait DirectProbeHashable: + Copy + PartialEq + std::hash::Hash + Eq +{ + fn probe_hash(self) -> usize; +} + +// Simple but fast hash - golden ratio multiply + xor-shift +impl DirectProbeHashable for i32 { + #[inline(always)] + fn probe_hash(self) -> usize { + let x = self as u32; + let x = x.wrapping_mul(GOLDEN_RATIO_32); + (x ^ (x >> 16)) as usize + } +} + +impl DirectProbeHashable for i64 { + #[inline(always)] + fn probe_hash(self) -> usize { + let x = self as u64; + let x = x.wrapping_mul(GOLDEN_RATIO_64); + (x ^ (x >> 32)) as usize + } +} + +impl DirectProbeHashable for u32 { + #[inline(always)] + fn probe_hash(self) -> usize { + (self as i32).probe_hash() + } +} + +impl DirectProbeHashable for u64 { + #[inline(always)] + fn probe_hash(self) -> usize { + (self as i64).probe_hash() + } +} + +impl DirectProbeHashable for i128 { + #[inline(always)] + fn probe_hash(self) -> usize { + // Mix both halves with different constants to avoid collisions when lo = hi * C + let lo = self as u64; + let hi = (self >> 64) as u64; + let x = 
lo.wrapping_mul(GOLDEN_RATIO_64) ^ hi.wrapping_mul(SPLITMIX_CONSTANT); + (x ^ (x >> 32)) as usize + } +} + +impl DirectProbeFilter +where + T::Native: DirectProbeHashable, +{ + pub(crate) fn try_new(in_array: &ArrayRef) -> Result { + let arr = in_array.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!( + "DirectProbeFilter: expected {} array", + std::any::type_name::() + ) + })?; + + // Collect unique values using HashSet for deduplication + let unique_values: HashSet<_> = arr.iter().flatten().collect(); + + Ok(Self::from_values_inner( + unique_values.into_iter(), + arr.null_count(), + )) + } + + /// Creates a DirectProbeFilter from an iterator of values. + /// + /// This is useful when building the filter from pre-processed values + /// (e.g., masked views for Utf8View). + pub(crate) fn from_values(values: impl Iterator) -> Self { + // Collect into HashSet for deduplication + let unique_values: HashSet<_> = values.collect(); + Self::from_values_inner(unique_values.into_iter(), 0) + } + + /// Internal constructor from deduplicated values + fn from_values_inner( + unique_values: impl Iterator, + null_count: usize, + ) -> Self { + let unique_values: Vec<_> = unique_values.collect(); + + // Size table to ~25% load factor for fewer collisions + let n = unique_values.len().max(1); + let table_size = (n * LOAD_FACTOR_INVERSE) + .next_power_of_two() + .max(MIN_TABLE_SIZE); + let mask = table_size - 1; + + let mut table: Box<[Option]> = + vec![None; table_size].into_boxed_slice(); + + // Insert all values using linear probing + for v in unique_values { + let mut slot = v.probe_hash() & mask; + loop { + if table[slot].is_none() { + table[slot] = Some(v); + break; + } + slot = (slot + 1) & mask; + } + } + + Self { + null_count, + table, + mask, + } + } + + /// O(1) single-value lookup with linear probing. + /// + /// Returns true if the value is in the set. 
+ #[inline(always)] + pub(crate) fn contains_single(&self, needle: T::Native) -> bool { + let mut slot = needle.probe_hash() & self.mask; + loop { + // SAFETY: `slot` is always < table.len() because: + // - `slot = hash & mask` where `mask = table.len() - 1` + // - table size is always a power of 2 + // - `(slot + 1) & mask` wraps around within bounds + match unsafe { self.table.get_unchecked(slot) } { + None => return false, + Some(v) if *v == needle => return true, + _ => slot = (slot + 1) & self.mask, + } + } + } + + /// Check membership using a raw values slice + #[inline] + pub(crate) fn contains_slice( + &self, + input: &[T::Native], + nulls: Option<&arrow::buffer::NullBuffer>, + negated: bool, + ) -> BooleanArray { + build_in_list_result(input.len(), nulls, self.null_count > 0, negated, |i| { + // SAFETY: i is in bounds since we iterate 0..input.len() + self.contains_single(unsafe { *input.get_unchecked(i) }) + }) + } +} + +impl StaticFilter for DirectProbeFilter +where + T: ArrowPrimitiveType + 'static, + T::Native: DirectProbeHashable + Send + Sync + 'static, +{ + #[inline] + fn null_count(&self) -> usize { + self.null_count + } + + #[inline] + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + // Use raw buffer access for better optimization + let data = v.to_data(); + let values: &[T::Native] = data.buffer::(0); + Ok(self.contains_slice(values, v.nulls(), negated)) + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list/result.rs b/datafusion/physical-expr/src/expressions/in_list/result.rs new file mode 100644 index 0000000000000..787d5dc67e16a --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/result.rs @@ -0,0 +1,243 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Result building helpers for InList operations +//! +//! This module provides unified logic for building BooleanArray results +//! from IN list membership tests, handling null propagation correctly +//! according to SQL three-valued logic. + +use arrow::array::BooleanArray; +use arrow::buffer::{BooleanBuffer, NullBuffer}; + +// ============================================================================= +// RESULT BUILDER FOR IN LIST OPERATIONS +// ============================================================================= +// +// Truth table for (needle_nulls, haystack_has_nulls, negated): +// (Some, true, false) → values: valid & contains, nulls: valid & contains +// (None, true, false) → values: contains, nulls: contains +// (Some, true, true) → values: valid ^ (valid & contains), nulls: valid & contains +// (None, true, true) → values: !contains, nulls: contains +// (Some, false, false) → values: valid & contains, nulls: valid +// (Some, false, true) → values: valid & !contains, nulls: valid +// (None, false, false) → values: contains, nulls: none +// (None, false, true) → values: !contains, nulls: none + +/// Builds a BooleanArray result for IN list operations (optimized for cheap contains). 
+/// +/// This function handles the complex null propagation logic for SQL IN lists: +/// - If the needle value is null, the result is null +/// - If the needle is not in the set AND the haystack has nulls, the result is null +/// - Otherwise, the result is true/false based on membership and negation +/// +/// This version computes contains for ALL positions (including nulls), then applies +/// null masking via bitmap operations. This is optimal for cheap contains checks +/// (like DirectProbeFilter) where the branch overhead exceeds the check cost. +/// +/// For expensive contains checks (like ByteViewMaskedFilter with string comparison), +/// use `build_in_list_result_with_null_shortcircuit` instead. +#[inline] +pub(crate) fn build_in_list_result( + len: usize, + needle_nulls: Option<&NullBuffer>, + haystack_has_nulls: bool, + negated: bool, + contains: C, +) -> BooleanArray +where + C: FnMut(usize) -> bool, +{ + // Always compute the contains buffer without checking nulls in the loop. + // The null check inside the loop hurts vectorization and branch prediction. + // Nulls are handled by build_result_from_contains using bitmap operations. + let contains_buf = BooleanBuffer::collect_bool(len, contains); + build_result_from_contains(needle_nulls, haystack_has_nulls, negated, contains_buf) +} + +/// Builds a BooleanArray result with null short-circuit (optimized for expensive contains). +/// +/// Unlike `build_in_list_result`, this version checks nulls INSIDE the loop and +/// skips the contains check for null positions. This is optimal for expensive +/// contains checks (like ByteViewMaskedFilter with hash lookup + string comparison) where +/// skipping lookups outweighs the branch overhead. +/// +/// The shortcircuit is only applied when `needle_null_count > 0` - if there are +/// no actual nulls, we avoid the branch overhead entirely. 
+/// +/// Use this for: ByteViewMaskedFilter, Utf8TwoStageFilter (string/binary types) +/// Use `build_in_list_result` for: DirectProbeFilter, BranchlessFilter (primitive types) +#[inline] +pub(crate) fn build_in_list_result_with_null_shortcircuit( + len: usize, + needle_nulls: Option<&NullBuffer>, + needle_null_count: usize, + haystack_has_nulls: bool, + negated: bool, + mut contains: C, +) -> BooleanArray +where + C: FnMut(usize) -> bool, +{ + // When null_count=0, treat as no validity buffer to avoid extra work. + // The validity buffer might exist but have all bits set to true. + let effective_nulls = needle_nulls.filter(|_| needle_null_count > 0); + + match effective_nulls { + Some(nulls) => { + // Has nulls: check validity inside loop to skip expensive contains() + let contains_buf = + BooleanBuffer::collect_bool(len, |i| nulls.is_valid(i) && contains(i)); + build_result_from_contains_premasked( + Some(nulls), + haystack_has_nulls, + negated, + contains_buf, + ) + } + None => { + // No nulls: compute contains for all positions without branch overhead + let contains_buf = BooleanBuffer::collect_bool(len, contains); + // Use premasked path since contains_buf is "trivially premasked" (no nulls to mask) + build_result_from_contains_premasked( + None, + haystack_has_nulls, + negated, + contains_buf, + ) + } + } +} + +/// Builds result from a contains buffer that was pre-masked at null positions. +/// +/// This is used by `build_in_list_result_with_null_shortcircuit` where the +/// contains buffer already has `false` at null positions due to the short-circuit. 
+/// +/// Since contains_buf is pre-masked (false at null positions), we can simplify: +/// - `valid & contains_buf` = `contains_buf` (already 0 where valid is 0) +/// - XOR can replace AND+NOT for the negated case: `valid ^ contains = valid & !contains` +#[inline] +fn build_result_from_contains_premasked( + needle_nulls: Option<&NullBuffer>, + haystack_has_nulls: bool, + negated: bool, + contains_buf: BooleanBuffer, +) -> BooleanArray { + match (needle_nulls, haystack_has_nulls, negated) { + // Haystack has nulls: result is null unless value is found + (_, true, false) => { + // contains_buf is already masked (false at null positions) + BooleanArray::new(contains_buf.clone(), Some(NullBuffer::new(contains_buf))) + } + (Some(v), true, true) => { + // NOT IN with haystack nulls: false if found, null if not found or needle is null + // XOR: valid ^ contains = 1 iff valid=1 and contains=0 (not found); null via validity below + BooleanArray::new( + v.inner() ^ &contains_buf, + Some(NullBuffer::new(contains_buf)), + ) + } + (None, true, true) => { + BooleanArray::new(!&contains_buf, Some(NullBuffer::new(contains_buf))) + } + // Haystack has no nulls: result validity follows needle validity + (Some(v), false, false) => { + // contains_buf is already masked, just use needle validity for nulls + BooleanArray::new(contains_buf, Some(v.clone())) + } + (Some(v), false, true) => { + // Need AND because !contains_buf is 1 at null positions + BooleanArray::new(v.inner() & &(!&contains_buf), Some(v.clone())) + } + (None, false, false) => BooleanArray::new(contains_buf, None), + (None, false, true) => BooleanArray::new(!&contains_buf, None), + } +} + +/// Builds a BooleanArray result from a pre-computed contains buffer. +/// +/// This version does NOT assume contains_buf is pre-masked at null positions. +/// It handles nulls using bitmap operations which are more vectorization-friendly. 
+#[inline] +pub(crate) fn build_result_from_contains( + needle_nulls: Option<&NullBuffer>, + haystack_has_nulls: bool, + negated: bool, + contains_buf: BooleanBuffer, +) -> BooleanArray { + match (needle_nulls, haystack_has_nulls, negated) { + // Haystack has nulls: result is null unless value is found + (Some(v), true, false) => { + // values: valid & contains, nulls: valid & contains + // Result is valid (not null) only when needle is valid AND found in haystack + let values = v.inner() & &contains_buf; + BooleanArray::new(values.clone(), Some(NullBuffer::new(values))) + } + (None, true, false) => { + BooleanArray::new(contains_buf.clone(), Some(NullBuffer::new(contains_buf))) + } + (Some(v), true, true) => { + // NOT IN with haystack nulls: false if found, null if not found or needle is null + // values: valid & !contains, nulls: valid & contains + // Result is valid only when needle is valid AND found (because NOT IN with + // haystack nulls yields NULL whenever the value is not found in the list) + let valid = v.inner(); + let values = valid & &(!&contains_buf); + let nulls = valid & &contains_buf; + BooleanArray::new(values, Some(NullBuffer::new(nulls))) + } + (None, true, true) => { + BooleanArray::new(!&contains_buf, Some(NullBuffer::new(contains_buf))) + } + // Haystack has no nulls: result validity follows needle validity + (Some(v), false, false) => { + // values: valid & contains (mask out nulls), nulls: valid + BooleanArray::new(v.inner() & &contains_buf, Some(v.clone())) + } + (Some(v), false, true) => { + // values: valid & !contains, nulls: valid + BooleanArray::new(v.inner() & &(!&contains_buf), Some(v.clone())) + } + (None, false, false) => BooleanArray::new(contains_buf, None), + (None, false, true) => BooleanArray::new(!&contains_buf, None), + } +} + +// ============================================================================= +// DICTIONARY ARRAY HANDLING +// ============================================================================= + +/// 
Macro to handle dictionary arrays in StaticFilter::contains implementations. +/// +/// This macro extracts the dictionary values, performs the contains check on +/// the values array, and then uses `take` to map the results back to the +/// dictionary keys. +macro_rules! handle_dictionary { + ($self:ident, $v:ident, $negated:ident) => { + arrow::array::downcast_dictionary_array! { + $v => { + let values_contains = $self.contains($v.values().as_ref(), $negated)?; + let result = arrow::compute::take(&values_contains, $v.keys(), None)?; + return Ok(arrow::array::downcast_array(result.as_ref())) + } + _ => {} + } + }; +} + +pub(crate) use handle_dictionary; diff --git a/datafusion/physical-expr/src/expressions/in_list/static_filter.rs b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs new file mode 100644 index 0000000000000..9dbc00d35125c --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Core trait for InList static filters + +use arrow::array::{Array, BooleanArray}; +use datafusion_common::Result; + +/// Trait for InList static filters. 
+/// +/// Static filters are pre-computed lookup structures that enable efficient +/// membership testing for IN list expressions. Different implementations +/// optimize for different data types: +/// +/// - [`super::primitive_filter::BitmapFilter`]: O(1) bit test for u8/u16 +/// - [`super::primitive_filter::BranchlessFilter`]: Unrolled OR-chain for small lists +/// - [`super::primitive_filter::DirectProbeFilter`]: O(1) hash lookups for larger primitive types +/// - [`super::transform::Utf8TwoStageFilter`]: Two-stage filter for Utf8/LargeUtf8 +/// - [`super::nested_filter::NestedTypeFilter`]: Dynamic comparator for complex types +pub(crate) trait StaticFilter { + /// Returns the number of null values in the filter's haystack. + fn null_count(&self) -> usize; + + /// Checks if values in `v` are contained in the filter. + /// + /// Returns a `BooleanArray` with the same length as `v`, where each element + /// indicates whether the corresponding value is in the filter (or NOT in, + /// if `negated` is true). + /// + /// Follows SQL three-valued logic: + /// - If the needle value is null, the result is null + /// - If the needle is not found AND the haystack contains nulls, the result is null + /// - Otherwise, the result is true/false based on membership + fn contains(&self, v: &dyn Array, negated: bool) -> Result; +} diff --git a/datafusion/physical-expr/src/expressions/in_list/strategy.rs b/datafusion/physical-expr/src/expressions/in_list/strategy.rs new file mode 100644 index 0000000000000..2abe1fe5044d1 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/strategy.rs @@ -0,0 +1,299 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Filter selection strategy for InList expressions +//! +//! Selects the optimal lookup strategy based on data type and list size: +//! +//! - 1-byte types (Int8/UInt8): bitmap (32 bytes, O(1) bit test) +//! - 2-byte types (Int16/UInt16): bitmap (8 KB, O(1) bit test) +//! - 4-byte types (Int32/Float32): branchless (≤32) or hash (>32) +//! - 8-byte types (Int64/Float64): branchless (≤16) or hash (>16) +//! - 16-byte types (Decimal128): branchless (≤4) or hash (>4) +//! - Utf8View (short strings): branchless (≤4) or hash (>4) +//! - Utf8/LargeUtf8 (all strings short): Utf8TwoStageFilter; Utf8View/BinaryView: ByteViewMaskedFilter +//! - Other types: NestedTypeFilter (fallback for Binary, long Utf8, List, Struct, Map, etc.) 
+ +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::datatypes::*; +use datafusion_common::{Result, exec_datafusion_err}; + +use super::nested_filter::NestedTypeFilter; +use super::primitive_filter::*; +use super::result::handle_dictionary; +use super::static_filter::StaticFilter; +use super::transform::{ + make_bitmap_filter, make_branchless_filter, make_byte_view_masked_filter, + make_utf8_two_stage_filter, make_utf8view_branchless_filter, + make_utf8view_hash_filter, utf8_all_short_strings, utf8view_all_short_strings, +}; + +// ============================================================================= +// LOOKUP STRATEGY THRESHOLDS (tuned via microbenchmarks) +// ============================================================================= +// +// Based on minimum batch time (8192 lookups per batch): +// - Int8 (1 byte): BITMAP (32 bytes, always fastest) +// - Int16 (2 bytes): BITMAP (8 KB, always fastest) +// - Int32 (4 bytes): branchless up to 32, then hash table (DirectProbeFilter) +// - Int64 (8 bytes): branchless up to 16, then hash table (DirectProbeFilter) +// - Decimal128 (16 bytes): branchless up to 4, then hash table (DirectProbeFilter) +// - Strings / byte views: Utf8TwoStageFilter (short Utf8) / ByteViewMaskedFilter +// - Other types: NestedTypeFilter (fallback for List, Struct, Map, etc.) +// +// NOTE: Binary search and linear scan were benchmarked but consistently +// lost to the strategies above at all tested list sizes. + +/// Maximum list size for branchless lookup on 4-byte primitives (Int32, UInt32, Float32). +const BRANCHLESS_MAX_4B: usize = 32; + +/// Maximum list size for branchless lookup on 8-byte primitives (Int64, UInt64, Float64). +const BRANCHLESS_MAX_8B: usize = 16; + +/// Maximum list size for branchless lookup on 16-byte types (Decimal128). 
+const BRANCHLESS_MAX_16B: usize = 4; + +// ============================================================================= +// FILTER STRATEGY SELECTION +// ============================================================================= + +/// The lookup strategy to use for a given data type and list size. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FilterStrategy { + /// Bitmap filter for 1-byte (`Bitmap1B`) and 2-byte (`Bitmap2B`) types - O(1) bit test. + Bitmap1B, + Bitmap2B, + /// Branchless OR-chain for small lists. + Branchless, + /// Hash table lookup (DirectProbeFilter) for larger lists. + Hashed, + /// Fallback for all other types (string/byte-view filters or NestedTypeFilter). + Generic, +} + +/// Determines the optimal lookup strategy based on data type and list size. +/// +/// For 1-byte and 2-byte types, bitmap is always used (benchmarks show it's +/// faster than both branchless and hashed at all list sizes). +/// For larger types, cutoffs are tuned per byte-width. +fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy { + match dt.primitive_width() { + Some(1) => FilterStrategy::Bitmap1B, + Some(2) => FilterStrategy::Bitmap2B, + Some(4) => { + if len <= BRANCHLESS_MAX_4B { + FilterStrategy::Branchless + } else { + FilterStrategy::Hashed + } + } + Some(8) => { + if len <= BRANCHLESS_MAX_8B { + FilterStrategy::Branchless + } else { + FilterStrategy::Hashed + } + } + Some(16) => { + if len <= BRANCHLESS_MAX_16B { + FilterStrategy::Branchless + } else { + FilterStrategy::Hashed + } + } + _ => FilterStrategy::Generic, + } +} + +// ============================================================================= +// FILTER INSTANTIATION +// ============================================================================= + +/// Creates the optimal static filter for the given array. +/// +/// This is the main entry point for filter creation. It analyzes the array's +/// data type and size to select the best lookup strategy. 
+pub(crate) fn instantiate_static_filter( + in_array: ArrayRef, +) -> Result> { + use FilterStrategy::*; + + let len = in_array.len(); + let dt = in_array.data_type(); + + // Special case: Utf8View with short strings can be reinterpreted as i128 + if matches!(dt, DataType::Utf8View) && utf8view_all_short_strings(in_array.as_ref()) { + return if len <= BRANCHLESS_MAX_16B { + make_utf8view_branchless_filter(&in_array) + } else { + make_utf8view_hash_filter(&in_array) + }; + } + + let strategy = select_strategy(dt, len); + + match (dt, strategy) { + // Bitmap filters for 1-byte and 2-byte types + (_, Bitmap1B) => make_bitmap_filter::(&in_array), + (_, Bitmap2B) => make_bitmap_filter::(&in_array), + + // Branchless filters for small lists of primitives + (_, Branchless) => dispatch_branchless(&in_array).ok_or_else(|| { + exec_datafusion_err!( + "Branchless strategy selected but no filter for {:?}", + dt + ) + })?, + + // Hash filters for larger lists of primitives + (_, Hashed) => dispatch_hashed(&in_array).ok_or_else(|| { + exec_datafusion_err!("Hashed strategy selected but no filter for {:?}", dt) + })?, + + // Utf8/LargeUtf8: Two-stage filter when all IN-list strings are short (≤12 bytes). + // Stage 1 encodes as i128 (length + first 12 bytes) for O(1) rejection. + // When strings are long, the encoding can't definitively match and the + // overhead regresses vs the generic fallback, so we skip it. 
+ (DataType::Utf8 | DataType::LargeUtf8, Generic) + if utf8_all_short_strings(in_array.as_ref()) => + { + make_utf8_two_stage_filter(in_array) + } + + // Binary variants: Use NestedTypeFilter (make_comparator) + (DataType::Binary | DataType::LargeBinary, Generic) => { + Ok(Arc::new(NestedTypeFilter::try_new(in_array)?)) + } + + // Byte view filters (Utf8View, BinaryView) + // Both use two-stage filter: masked view pre-check + full verification + (DataType::Utf8View, Generic) => { + make_byte_view_masked_filter::(in_array) + } + (DataType::BinaryView, Generic) => { + make_byte_view_masked_filter::(in_array) + } + + // Fallback for nested/complex types (List, Struct, Map, Union, etc.) + (_, Generic) => Ok(Arc::new(NestedTypeFilter::try_new(in_array)?)), + } +} + +// ============================================================================= +// TYPE DISPATCH +// ============================================================================= + +fn dispatch_branchless( + arr: &ArrayRef, +) -> Option>> { + // Dispatch to width-specific branchless filter. + // Each width has its own max size: 4B→32, 8B→16, 16B→4 + match arr.data_type().primitive_width() { + Some(4) => Some(make_branchless_filter::(arr, 4)), + Some(8) => Some(make_branchless_filter::(arr, 8)), + Some(16) => Some(make_branchless_filter::(arr, 16)), + _ => None, + } +} + +fn dispatch_hashed( + arr: &ArrayRef, +) -> Option>> { + // Use DirectProbeFilter for fast hash table lookups + macro_rules! 
direct_probe_filter { + ($T:ty) => { + return Some( + DirectProbeFilter::<$T>::try_new(arr) + .map(|f| Arc::new(f) as Arc), + ) + }; + } + match arr.data_type() { + DataType::Int32 => direct_probe_filter!(Int32Type), + DataType::Int64 => direct_probe_filter!(Int64Type), + DataType::UInt32 => direct_probe_filter!(UInt32Type), + DataType::UInt64 => direct_probe_filter!(UInt64Type), + _ => {} + } + + // For other primitive types, reinterpret bits as appropriate UInt/Int type + match arr.data_type().primitive_width() { + Some(4) => Some(make_direct_probe_filter_reinterpreted::(arr)), + Some(8) => Some(make_direct_probe_filter_reinterpreted::(arr)), + Some(16) => Some(make_direct_probe_filter_reinterpreted::( + arr, + )), + // Other widths (1, 2) use Bitmap strategy and never reach here. + // Unknown widths fall through to Generic strategy. + _ => None, + } +} + +/// Creates a DirectProbeFilter with type reinterpretation for Float types +fn make_direct_probe_filter_reinterpreted( + in_array: &ArrayRef, +) -> Result> +where + D: ArrowPrimitiveType + 'static, + D::Native: Send + Sync + DirectProbeHashable + 'static, +{ + use super::transform::reinterpret_any_primitive_to; + + // Fast path: already the right type + if in_array.data_type() == &D::DATA_TYPE { + return Ok(Arc::new(DirectProbeFilter::::try_new(in_array)?)); + } + + // Reinterpret and create filter + let reinterpreted = reinterpret_any_primitive_to::(in_array.as_ref()); + let inner = DirectProbeFilter::::try_new(&reinterpreted)?; + Ok(Arc::new(ReinterpretedDirectProbeFilter { inner })) +} + +/// Wrapper for DirectProbeFilter with type reinterpretation +struct ReinterpretedDirectProbeFilter +where + D::Native: DirectProbeHashable, +{ + inner: DirectProbeFilter, +} + +impl StaticFilter for ReinterpretedDirectProbeFilter +where + D: ArrowPrimitiveType + 'static, + D::Native: Send + Sync + DirectProbeHashable + 'static, +{ + #[inline] + fn null_count(&self) -> usize { + self.inner.null_count() + } + + #[inline] + 
fn contains( + &self, + v: &dyn arrow::array::Array, + negated: bool, + ) -> Result { + handle_dictionary!(self, v, negated); + // Reinterpret needle array to destination type and use inner filter's raw slice path + let data = v.to_data(); + let values: &[D::Native] = data.buffer::(0); + Ok(self.inner.contains_slice(values, v.nulls(), negated)) + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list/transform.rs b/datafusion/physical-expr/src/expressions/in_list/transform.rs new file mode 100644 index 0000000000000..e50b642bd432d --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/transform.rs @@ -0,0 +1,779 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Type transformation utilities for InList filters +//! +//! This module provides type reinterpretation for optimizing filter dispatch. +//! For equality comparison, only the bit pattern matters, so we can: +//! - Reinterpret signed integers as unsigned (Int32 → UInt32) +//! - Reinterpret floats as unsigned integers (Float64 → UInt64) +//! +//! This allows using a single filter implementation (e.g., for UInt64) to handle +//! multiple types (Int64, Float64, Timestamp, Duration) that share the same +//! 
byte width, reducing code duplication. + +use std::marker::PhantomData; +use std::sync::Arc; + +use ahash::RandomState; +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray, PrimitiveArray}; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::{ArrowPrimitiveType, ByteViewType, Decimal128Type}; +use arrow::util::bit_iterator::BitIndexIterator; +use datafusion_common::Result; +use datafusion_common::hash_utils::with_hashes; +use hashbrown::HashTable; + +use super::primitive_filter::{ + BitmapFilter, BitmapFilterConfig, BranchlessFilter, DirectProbeFilter, +}; +use super::result::{build_in_list_result_with_null_shortcircuit, handle_dictionary}; +use super::static_filter::StaticFilter; + +/// Maximum length for inline strings (≤12 bytes can be stored in 16-byte view/encoding). +/// Used by both Utf8View short string optimization and Utf8 two-stage filter. +pub(crate) const INLINE_STRING_LEN: usize = 12; + +// ============================================================================= +// REINTERPRETING FILTERS (zero-copy type conversion) +// ============================================================================= + +/// Reinterpreting filter for bitmap lookups (u8/u16). +struct ReinterpretedBitmap { + inner: BitmapFilter, +} + +impl StaticFilter for ReinterpretedBitmap { + fn null_count(&self) -> usize { + self.inner.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + let data = v.to_data(); + let values: &[C::Native] = data.buffer::(0); + + Ok(self.inner.contains_slice(values, data.nulls(), negated)) + } +} + +/// Reinterpreting filter for branchless lookups. 
+struct ReinterpretedBranchless { + inner: BranchlessFilter, +} + +impl StaticFilter for ReinterpretedBranchless +where + T: ArrowPrimitiveType + 'static, + T::Native: Copy + PartialEq + Send + Sync + 'static, +{ + fn null_count(&self) -> usize { + self.inner.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + let data = v.to_data(); + let values: &[T::Native] = data.buffer::(0); + + Ok(self.inner.contains_slice(values, data.nulls(), negated)) + } +} + +/// Hash filter for Utf8View short strings (≤12 bytes). +/// +/// Reinterprets the views buffer directly as i128 slice. +struct Utf8ViewHashFilter { + inner: DirectProbeFilter, +} + +impl StaticFilter for Utf8ViewHashFilter { + fn null_count(&self) -> usize { + self.inner.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + // Zero-copy: reinterpret views buffer directly as i128 slice + let sv = v.as_string_view(); + let values: &[i128] = sv.views().inner().typed_data(); + + Ok(self.inner.contains_slice(values, sv.nulls(), negated)) + } +} + +/// Reinterprets any primitive-like array as the target primitive type T by extracting +/// the underlying buffer. +/// +/// This is a zero-copy operation that works for all primitive types (Int*, UInt*, Float*, +/// Timestamp*, Date*, Duration*, etc.) by directly accessing the underlying buffer, +/// ignoring any metadata like timezones or precision/scale. +#[inline] +pub(crate) fn reinterpret_any_primitive_to( + array: &dyn Array, +) -> ArrayRef { + let values = array.to_data().buffers()[0].clone(); + let buffer: ScalarBuffer = values.into(); + Arc::new(PrimitiveArray::::new(buffer, array.nulls().cloned())) +} + +/// Creates a bitmap filter for u8/u16 types, reinterpreting if needed. 
+pub(crate) fn make_bitmap_filter( + in_array: &ArrayRef, +) -> Result> +where + C: BitmapFilterConfig, +{ + if in_array.data_type() == &C::ArrowType::DATA_TYPE { + return Ok(Arc::new(BitmapFilter::::try_new(in_array)?)); + } + + let reinterpreted = reinterpret_any_primitive_to::(in_array.as_ref()); + let inner = BitmapFilter::::try_new(&reinterpreted)?; + Ok(Arc::new(ReinterpretedBitmap { inner })) +} + +// ============================================================================= +// BRANCHLESS FILTER CREATION (const generic dispatch) +// ============================================================================= + +/// Creates a branchless filter for primitive types. +/// +/// Dispatches based on byte width and element count: +/// - 4-byte types (Int32, Float32, etc.): supports 0-32 elements +/// - 8-byte types (Int64, Float64, Timestamp, etc.): supports 0-16 elements +/// - 16-byte types (Decimal128): supports 0-4 elements +pub(crate) fn make_branchless_filter( + in_array: &ArrayRef, + width: usize, +) -> Result> +where + D: ArrowPrimitiveType + 'static, + D::Native: Copy + PartialEq + Send + Sync + 'static, +{ + let is_native = in_array.data_type() == &D::DATA_TYPE; + let arr = if is_native { + Arc::clone(in_array) + } else { + reinterpret_any_primitive_to::(in_array.as_ref()) + }; + let n = arr.len() - arr.null_count(); + + // Helper to create the filter for a known size N + #[inline] + fn create( + arr: &ArrayRef, + is_native: bool, + ) -> Result> + where + D::Native: Copy + PartialEq + Send + Sync + 'static, + { + let inner = BranchlessFilter::::try_new(arr) + .expect("size verified") + .expect("type verified"); + if is_native { + Ok(Arc::new(inner)) + } else { + Ok(Arc::new(ReinterpretedBranchless { inner })) + } + } + + // Match on (width, count) - shared sizes use or-patterns to avoid duplication + match (width, n) { + // All widths: 0-4 + (4 | 8 | 16, 0) => create::(&arr, is_native), + (4 | 8 | 16, 1) => create::(&arr, is_native), + (4 | 8 | 16, 2) 
=> create::(&arr, is_native), + (4 | 8 | 16, 3) => create::(&arr, is_native), + (4 | 8 | 16, 4) => create::(&arr, is_native), + // 4-byte and 8-byte: 5-16 + (4 | 8, 5) => create::(&arr, is_native), + (4 | 8, 6) => create::(&arr, is_native), + (4 | 8, 7) => create::(&arr, is_native), + (4 | 8, 8) => create::(&arr, is_native), + (4 | 8, 9) => create::(&arr, is_native), + (4 | 8, 10) => create::(&arr, is_native), + (4 | 8, 11) => create::(&arr, is_native), + (4 | 8, 12) => create::(&arr, is_native), + (4 | 8, 13) => create::(&arr, is_native), + (4 | 8, 14) => create::(&arr, is_native), + (4 | 8, 15) => create::(&arr, is_native), + (4 | 8, 16) => create::(&arr, is_native), + // 4-byte only: 17-32 + (4, 17) => create::(&arr, is_native), + (4, 18) => create::(&arr, is_native), + (4, 19) => create::(&arr, is_native), + (4, 20) => create::(&arr, is_native), + (4, 21) => create::(&arr, is_native), + (4, 22) => create::(&arr, is_native), + (4, 23) => create::(&arr, is_native), + (4, 24) => create::(&arr, is_native), + (4, 25) => create::(&arr, is_native), + (4, 26) => create::(&arr, is_native), + (4, 27) => create::(&arr, is_native), + (4, 28) => create::(&arr, is_native), + (4, 29) => create::(&arr, is_native), + (4, 30) => create::(&arr, is_native), + (4, 31) => create::(&arr, is_native), + (4, 32) => create::(&arr, is_native), + // Error cases + (4, n) => datafusion_common::exec_err!( + "Branchless filter for 4-byte types supports 0-32 elements, got {n}" + ), + (8, n) => datafusion_common::exec_err!( + "Branchless filter for 8-byte types supports 0-16 elements, got {n}" + ), + (16, n) => datafusion_common::exec_err!( + "Branchless filter for 16-byte types supports 0-4 elements, got {n}" + ), + (w, _) => datafusion_common::exec_err!( + "Branchless filter not supported for {w}-byte types" + ), + } +} + +// ============================================================================= +// UTF8VIEW REINTERPRETATION (short strings ≤12 bytes → Decimal128) +// 
============================================================================= + +// NOTE: Optimizations below assume Little Endian layout (DataFusion standard). + +/// Helper to extract the length from a Utf8View u128/i128 view. +#[inline(always)] +fn view_len(view: i128) -> u32 { + view as u32 +} + +/// Checks if all strings in a Utf8View array are short enough to be inline. +/// +/// In Utf8View, strings ≤12 bytes are stored inline in the 16-byte view struct. +/// These can be reinterpreted as i128 for fast equality comparison. +#[inline] +pub(crate) fn utf8view_all_short_strings(array: &dyn Array) -> bool { + let sv = array.as_string_view(); + sv.views().iter().enumerate().all(|(i, &view)| { + !sv.is_valid(i) || view_len(view as i128) as usize <= INLINE_STRING_LEN + }) +} + +/// Reinterprets a Utf8View array as Decimal128 by treating the view bytes as i128. +#[inline] +fn reinterpret_utf8view_as_decimal128(array: &dyn Array) -> ArrayRef { + let sv = array.as_string_view(); + let buffer: ScalarBuffer = sv.views().inner().clone().into(); + Arc::new(PrimitiveArray::::new( + buffer, + sv.nulls().cloned(), + )) +} + +/// Creates a hash filter for Utf8View arrays with short strings. +pub(crate) fn make_utf8view_hash_filter( + in_array: &ArrayRef, +) -> Result> { + let reinterpreted = reinterpret_utf8view_as_decimal128(in_array.as_ref()); + let inner = DirectProbeFilter::::try_new(&reinterpreted)?; + Ok(Arc::new(Utf8ViewHashFilter { inner })) +} + +/// Creates a branchless filter for Utf8View arrays with short strings. +pub(crate) fn make_utf8view_branchless_filter( + in_array: &ArrayRef, +) -> Result> { + let reinterpreted = reinterpret_utf8view_as_decimal128(in_array.as_ref()); + + macro_rules! 
try_branchless { + ($($n:literal),*) => { + $(if let Some(Ok(inner)) = BranchlessFilter::::try_new(&reinterpreted) { + return Ok(Arc::new(Utf8ViewBranchless { inner })); + })* + }; + } + try_branchless!(0, 1, 2, 3, 4); + + datafusion_common::exec_err!( + "Utf8View branchless filter only supports 0-4 elements, got {}", + in_array.len() - in_array.null_count() + ) +} + +/// Branchless filter for Utf8View short strings (≤12 bytes). +struct Utf8ViewBranchless { + inner: BranchlessFilter, +} + +impl StaticFilter for Utf8ViewBranchless { + fn null_count(&self) -> usize { + self.inner.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + let sv = v.as_string_view(); + let values: &[i128] = sv.views().inner().typed_data(); + + Ok(self.inner.contains_slice(values, sv.nulls(), negated)) + } +} + +// ============================================================================= +// UTF8VIEW TWO-STAGE FILTER (masked view pre-check + full verification) +// ============================================================================= + +/// Mask to extract len + prefix from a Utf8View view (zeroes out buffer_index and offset). +/// +/// View layout (16 bytes, Little Endian): +/// - Bytes 0-3 (low): length (u32) +/// - Bytes 4-7: prefix (long strings) or inline data bytes 0-3 (short strings) +/// - Bytes 8-11: buffer_index (long) or inline data bytes 4-7 (short) +/// - Bytes 12-15 (high): offset (long) or inline data bytes 8-11 (short) +/// +/// For long strings (>12 bytes), buffer_index and offset are array-specific, +/// so we mask them out, keeping only len + prefix for comparison. +const VIEW_MASK_LONG: i128 = (1_i128 << 64) - 1; // Keep low 64 bits + +/// Computes the masked view for comparison. 
+/// +/// - Short strings (≤12 bytes): returns full view (all data is inline) +/// - Long strings (>12 bytes): returns only len + prefix (masks out buffer_index/offset) +#[inline(always)] +fn masked_view(view: i128) -> i128 { + let len = view_len(view) as usize; + + if len <= INLINE_STRING_LEN { + view // Short string: all 16 bytes are meaningful data + } else { + view & VIEW_MASK_LONG // Long string: keep only len + prefix + } +} + +/// Two-stage filter for ByteView arrays (Utf8View, BinaryView) with mixed lengths. +/// +/// Stage 1: Quick rejection using masked views (len + prefix as i128) +/// - Non-matches rejected without any hashing using DirectProbeFilter +/// - Short value matches (≤12 bytes) accepted immediately +/// +/// Stage 2: Full verification for long value matches +/// - Only reached when masked view matches AND value is long (>12 bytes) +/// - Uses HashTable lookup with indices into haystack array +pub(crate) struct ByteViewMaskedFilter { + /// The haystack array containing values to match against. 
+ in_array: ArrayRef, + /// DirectProbeFilter for O(1) masked view quick rejection (faster than HashSet) + masked_view_filter: DirectProbeFilter, + /// HashTable storing indices of long strings for Stage 2 verification + long_value_table: HashTable, + /// Random state for consistent hashing between haystack and needles + state: RandomState, + _phantom: PhantomData, +} + +impl ByteViewMaskedFilter +where + T::Native: PartialEq, +{ + pub(crate) fn try_new(in_array: ArrayRef) -> Result { + let bv = in_array.as_byte_view::(); + let views: &[i128] = bv.views().inner().typed_data(); + + let mut masked_views = Vec::new(); + let state = RandomState::new(); + let mut long_value_table = HashTable::new(); + + // Build hash table for long strings using batch hashing + with_hashes([in_array.as_ref()], &state, |hashes| { + let mut process_idx = |idx: usize| { + let view = views[idx]; + masked_views.push(masked_view(view)); + + // For long strings, store index in hash table + let len = view_len(view) as usize; + if len > INLINE_STRING_LEN { + let hash = hashes[idx]; + // SAFETY: idx is valid from iterator + let val = unsafe { bv.value_unchecked(idx) }; + let bytes: &[u8] = val.as_ref(); + + // Only insert if not already present (deduplication) + if long_value_table + .find(hash, |&stored_idx| { + let stored: &[u8] = + unsafe { bv.value_unchecked(stored_idx) }.as_ref(); + stored == bytes + }) + .is_none() + { + long_value_table.insert_unique(hash, idx, |&i| hashes[i]); + } + } + }; + + match bv.nulls() { + Some(nulls) => { + BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) + .for_each(&mut process_idx); + } + None => { + (0..in_array.len()).for_each(&mut process_idx); + } + } + Ok::<_, datafusion_common::DataFusionError>(()) + })?; + + // Build DirectProbeFilter from collected masked views + let masked_view_filter = + DirectProbeFilter::::from_values(masked_views.into_iter()); + + Ok(Self { + in_array, + masked_view_filter, + long_value_table, + state, + 
_phantom: PhantomData, + }) + } +} + +impl StaticFilter for ByteViewMaskedFilter +where + T::Native: PartialEq, +{ + fn null_count(&self) -> usize { + self.in_array.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + let needle_bv = v.as_byte_view::(); + let needle_views: &[i128] = needle_bv.views().inner().typed_data(); + let needle_null_count = needle_bv.null_count(); + let haystack_has_nulls = self.in_array.null_count() > 0; + let haystack_bv = self.in_array.as_byte_view::(); + + // Single pass with lazy hashing - only hash long values that pass Stage 1 + // Use null shortcircuit: Stage 2 string comparison is expensive, + // so skipping lookups for null positions is worth the branch overhead + Ok(build_in_list_result_with_null_shortcircuit( + v.len(), + needle_bv.nulls(), + needle_null_count, + haystack_has_nulls, + negated, + #[inline(always)] + |i| { + let needle_view = needle_views[i]; + let masked = masked_view(needle_view); + + // Stage 1: Quick rejection via DirectProbeFilter (O(1) lookup) + if !self.masked_view_filter.contains_single(masked) { + return false; + } + + // Masked view found in set + let needle_len = view_len(needle_view) as usize; + + if needle_len <= INLINE_STRING_LEN { + // Short value: masked view = full view, true match + return true; + } + + // Stage 2: Long value - hash lazily and lookup in hash table + // SAFETY: i is in bounds, closure only called for valid positions + let needle_val = unsafe { needle_bv.value_unchecked(i) }; + let needle_bytes: &[u8] = needle_val.as_ref(); + let hash = self.state.hash_one(needle_bytes); + + self.long_value_table + .find(hash, |&idx| { + let haystack_val: &[u8] = + unsafe { haystack_bv.value_unchecked(idx) }.as_ref(); + haystack_val == needle_bytes + }) + .is_some() + }, + )) + } +} + +/// Creates a two-stage filter for ByteView arrays (Utf8View, BinaryView). 
+pub(crate) fn make_byte_view_masked_filter( + in_array: ArrayRef, +) -> Result> +where + T::Native: PartialEq, +{ + Ok(Arc::new(ByteViewMaskedFilter::::try_new(in_array)?)) +} + +// ============================================================================= +// UTF8 TWO-STAGE FILTER (length+prefix pre-check + full verification) +// ============================================================================= +// +// Similar to ByteViewMaskedFilter but for regular Utf8/LargeUtf8 arrays. +// Encodes strings as i128 with length + prefix for quick rejection. +// +// Encoding (Little Endian): +// - Bytes 0-3: length (u32) +// - Bytes 4-15: data (12 bytes) +// +// This naturally distinguishes short from long strings via the length field. +// For short strings (≤12 bytes), the i128 contains all data → match is definitive. +// For long strings (>12 bytes), a match requires full string comparison. + +/// Encodes a string as i128 with length + prefix. +/// Format: [len:u32][data:12 bytes] (Little Endian) +#[inline(always)] +fn encode_string_as_i128(s: &[u8]) -> i128 { + let len = s.len(); + + // Optimization: Construct the i128 directly using arithmetic and pointer copy + // to avoid Store-to-Load Forwarding (STLF) stalls on x64 and minimize LSU pressure on ARM. + // + // The layout in memory must match Utf8View: [4 bytes len][12 bytes data] + let mut val: u128 = len as u128; // Length in bytes 0-3 + + // Safety: writing to the remaining bytes of an initialized u128. + // We use a pointer copy for the string data as it is variable length (0-12 bytes). + unsafe { + let dst = (&mut val as *mut u128 as *mut u8).add(4); + std::ptr::copy_nonoverlapping(s.as_ptr(), dst, len.min(INLINE_STRING_LEN)); + } + + val as i128 +} + +/// Two-stage filter for Utf8/LargeUtf8 arrays. 
+/// +/// Stage 1: Quick rejection using length+prefix as i128 +/// - Non-matches rejected via O(1) DirectProbeFilter lookup +/// - Short string matches (≤12 bytes) accepted immediately +/// +/// Stage 2: Full verification for long string matches +/// - Only reached when encoded i128 matches AND string length >12 bytes +/// - Uses HashTable with full string comparison +pub(crate) struct Utf8TwoStageFilter { + /// The haystack array containing values to match against + in_array: ArrayRef, + /// DirectProbeFilter for O(1) encoded i128 quick rejection + encoded_filter: DirectProbeFilter, + /// HashTable storing indices of long strings (>12 bytes) for Stage 2 + long_string_table: HashTable, + /// Random state for consistent hashing + state: RandomState, + /// Whether all haystack strings are short (≤12 bytes) - enables fast path + all_short: bool, + _phantom: PhantomData, +} + +impl Utf8TwoStageFilter { + pub(crate) fn try_new(in_array: ArrayRef) -> Result { + use arrow::array::GenericStringArray; + + let arr = in_array + .as_any() + .downcast_ref::>() + .expect("Utf8TwoStageFilter requires GenericStringArray"); + + let len = arr.len(); + let mut encoded_values = Vec::with_capacity(len); + let state = RandomState::new(); + let mut long_string_table = HashTable::new(); + let mut all_short = true; + + // Build encoded values and long string table + for i in 0..len { + if arr.is_null(i) { + encoded_values.push(0); + continue; + } + + let s = arr.value(i); + let bytes = s.as_bytes(); + encoded_values.push(encode_string_as_i128(bytes)); + + if bytes.len() > INLINE_STRING_LEN { + all_short = false; + // Add to long string table for Stage 2 verification (with deduplication) + let hash = state.hash_one(bytes); + if long_string_table + .find(hash, |&stored_idx| { + arr.value(stored_idx).as_bytes() == bytes + }) + .is_none() + { + long_string_table.insert_unique(hash, i, |&idx| { + state.hash_one(arr.value(idx).as_bytes()) + }); + } + } + } + + // Build DirectProbeFilter from 
encoded values + let nulls = arr + .nulls() + .map(|n| arrow::buffer::NullBuffer::new(n.inner().clone())); + let encoded_array: ArrayRef = Arc::new(PrimitiveArray::::new( + ScalarBuffer::from(encoded_values), + nulls, + )); + let encoded_filter = + DirectProbeFilter::::try_new(&encoded_array)?; + + Ok(Self { + in_array, + encoded_filter, + long_string_table, + state, + all_short, + _phantom: PhantomData, + }) + } +} + +impl StaticFilter for Utf8TwoStageFilter { + fn null_count(&self) -> usize { + self.in_array.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + use arrow::array::GenericStringArray; + + handle_dictionary!(self, v, negated); + + let needle_arr = v + .as_any() + .downcast_ref::>() + .expect("needle array type mismatch in Utf8TwoStageFilter"); + let haystack_arr = self + .in_array + .as_any() + .downcast_ref::>() + .expect("haystack array type mismatch in Utf8TwoStageFilter"); + + let haystack_has_nulls = self.in_array.null_count() > 0; + + if self.all_short { + // Fast path: all haystack strings are short + // Batch-encode all needles and do bulk lookup + let needle_encoded: Vec = (0..needle_arr.len()) + .map(|i| { + if needle_arr.is_null(i) { + 0 + } else { + encode_string_as_i128(needle_arr.value(i).as_bytes()) + } + }) + .collect(); + + // For short haystack, encoded match is definitive for short needles. + // Long needles (>12 bytes) can never match, but their encoded form + // won't match any short haystack encoding (different length field). 
+ return Ok(self.encoded_filter.contains_slice( + &needle_encoded, + needle_arr.nulls(), + negated, + )); + } + + // Two-stage path: haystack has long strings + Ok(super::result::build_in_list_result( + v.len(), + needle_arr.nulls(), + haystack_has_nulls, + negated, + |i| { + // SAFETY: i is in bounds [0, v.len()), guaranteed by build_in_list_result + let needle_bytes = unsafe { needle_arr.value_unchecked(i) }.as_bytes(); + let encoded = encode_string_as_i128(needle_bytes); + + // Stage 1: Quick rejection via encoded i128 + if !self.encoded_filter.contains_single(encoded) { + return false; + } + + // Encoded match found + let needle_len = needle_bytes.len(); + if needle_len <= INLINE_STRING_LEN { + // Short needle: encoded contains all data, match is definitive + // (If haystack had a long string with same prefix, its length + // field would differ, so encoded wouldn't match) + return true; + } + + // Stage 2: Long needle - verify with full string comparison + let hash = self.state.hash_one(needle_bytes); + self.long_string_table + .find(hash, |&idx| { + // SAFETY: idx was stored in try_new from valid indices into in_array + unsafe { haystack_arr.value_unchecked(idx) }.as_bytes() + == needle_bytes + }) + .is_some() + }, + )) + } +} + +/// Creates a two-stage filter for Utf8/LargeUtf8 arrays. +/// Returns true if all non-null strings in a Utf8/LargeUtf8 array are ≤12 bytes. +/// When false, the two-stage filter's Stage 1 cannot definitively match and the +/// encoding overhead regresses performance vs the generic fallback. 
+pub(crate) fn utf8_all_short_strings(array: &dyn Array) -> bool { + use arrow::array::GenericStringArray; + use arrow::datatypes::DataType; + match array.data_type() { + DataType::Utf8 => utf8_all_short_strings_impl::( + array + .as_any() + .downcast_ref::>() + .unwrap(), + ), + DataType::LargeUtf8 => utf8_all_short_strings_impl::( + array + .as_any() + .downcast_ref::>() + .unwrap(), + ), + _ => false, + } +} + +fn utf8_all_short_strings_impl( + arr: &arrow::array::GenericStringArray, +) -> bool { + (0..arr.len()).all(|i| arr.is_null(i) || arr.value(i).len() <= INLINE_STRING_LEN) +} + +pub(crate) fn make_utf8_two_stage_filter( + in_array: ArrayRef, +) -> Result> { + use arrow::datatypes::DataType; + match in_array.data_type() { + DataType::Utf8 => Ok(Arc::new(Utf8TwoStageFilter::::try_new(in_array)?)), + DataType::LargeUtf8 => { + Ok(Arc::new(Utf8TwoStageFilter::::try_new(in_array)?)) + } + dt => datafusion_common::exec_err!( + "Unsupported data type for Utf8 two-stage filter: {dt}" + ), + } +}