diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs
index 66a30a903f4..e3e4757aa5e 100644
--- a/rust/lance/src/dataset/write/merge_insert.rs
+++ b/rust/lance/src/dataset/write/merge_insert.rs
@@ -2198,6 +2198,9 @@ impl Merger {
 mod tests {
     use super::*;
     use crate::dataset::scanner::ColumnOrdering;
+    use crate::dataset::write::merge_insert::inserted_rows::{
+        extract_key_value_from_batch, KeyExistenceFilter, KeyExistenceFilterBuilder,
+    };
     use crate::index::vector::VectorIndexParams;
     use crate::io::commit::read_transaction_file;
     use crate::{
@@ -2208,13 +2211,15 @@ mod tests {
             FragmentRowCount, ThrottledStoreWrapper,
         },
     };
+    use arrow_array::builder::{ListBuilder, StringBuilder};
     use arrow_array::types::Float32Type;
     use arrow_array::RecordBatch;
     use arrow_array::{
         types::{Int32Type, UInt32Type},
-        FixedSizeListArray, Float32Array, Float64Array, Int32Array, Int64Array,
-        RecordBatchIterator, RecordBatchReader, StringArray, UInt32Array,
+        Array, FixedSizeListArray, Float32Array, Float64Array, Int32Array, Int64Array, ListArray,
+        RecordBatchIterator, RecordBatchReader, StringArray, StructArray, UInt32Array,
     };
+    use arrow_buffer::{OffsetBuffer, ScalarBuffer};
     use arrow_schema::{DataType, Field, Schema};
     use arrow_select::concat::concat_batches;
     use datafusion::common::Column;
@@ -4553,6 +4558,443 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_concurrent_insert_different_new_list_key() {
+        // Schema for list(string) key column "tags".
+        let tags_field = Field::new(
+            "tags",
+            DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
+            false,
+        );
+        let schema = Arc::new(Schema::new(vec![tags_field]));
+
+        // Build two batches inserting list keys ["a", "b"] and ["c", "d"].
+        let mut builder = ListBuilder::new(StringBuilder::new());
+        builder.append_value(["a", "b"].iter().copied().map(Some));
+        let tags_array1 = builder.finish();
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(tags_array1)]).unwrap();
+
+        let mut builder = ListBuilder::new(StringBuilder::new());
+        builder.append_value(["c", "d"].iter().copied().map(Some));
+        let tags_array2 = builder.finish();
+        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(tags_array2)]).unwrap();
+
+        // Build bloom filters for the list keys.
+        let field_ids = vec![0_i32];
+        let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
+        let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
+
+        let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("tags")])
+            .expect("first batch should produce key");
+        let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("tags")])
+            .expect("second batch should produce key");
+
+        builder1.insert(key1).unwrap();
+        builder2.insert(key2).unwrap();
+        let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
+        let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
+
+        let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
+        assert!(
+            !has_intersection,
+            "Expected bloom filters not to intersect for different list(string) keys",
+        );
+        assert!(
+            !might_be_fp,
+            "Bloom filter intersection should definitively indicate no conflict",
+        );
+    }
+
+    #[test]
+    fn test_concurrent_insert_same_new_list_key() {
+        // Schema for list(string) key column "tags".
+        let tags_field = Field::new(
+            "tags",
+            DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
+            false,
+        );
+        let schema = Arc::new(Schema::new(vec![tags_field]));
+
+        // Build two batches both inserting the same list key ["a", "b"].
+        let mut builder = ListBuilder::new(StringBuilder::new());
+        builder.append_value(["a", "b"].iter().copied().map(Some));
+        let tags_array1 = builder.finish();
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(tags_array1)]).unwrap();
+
+        let mut builder = ListBuilder::new(StringBuilder::new());
+        builder.append_value(["a", "b"].iter().copied().map(Some));
+        let tags_array2 = builder.finish();
+        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(tags_array2)]).unwrap();
+
+        // Build bloom filters for the list key.
+        let field_ids = vec![0_i32];
+        let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
+        let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
+
+        let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("tags")])
+            .expect("first batch should produce key");
+        let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("tags")])
+            .expect("second batch should produce key");
+
+        builder1.insert(key1).unwrap();
+        builder2.insert(key2).unwrap();
+        let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
+        let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
+
+        let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
+        assert!(
+            has_intersection,
+            "Expected bloom filters to intersect for identical list(string) keys",
+        );
+        assert!(
+            might_be_fp,
+            "Bloom filter intersection should be treated as potential conflict",
+        );
+    }
+
+    #[test]
+    fn test_concurrent_insert_same_new_nested_list_key() {
+        // Build nested list(list(string)) value [["a", "b"], ["c"]] for the "tags" column.
+        let nested_tags = make_nested_array(&[["a", "b"].as_slice(), ["c"].as_slice()]);
+        let tags_field = Field::new("tags", nested_tags.data_type().clone(), false);
+        let nested_tags2 = make_nested_array(&[["a", "b"].as_slice(), ["c"].as_slice()]);
+
+        let schema = Arc::new(Schema::new(vec![tags_field]));
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(nested_tags)]).unwrap();
+        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(nested_tags2)]).unwrap();
+
+        // Build bloom filters for the nested list key.
+        let field_ids = vec![0_i32];
+        let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
+        let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
+
+        let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("tags")])
+            .expect("first batch should produce key");
+        let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("tags")])
+            .expect("second batch should produce key");
+
+        builder1.insert(key1).unwrap();
+        builder2.insert(key2).unwrap();
+        let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
+        let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
+
+        let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
+        assert!(
+            has_intersection,
+            "Expected bloom filters to intersect for identical nested list(list(string)) keys",
+        );
+        assert!(
+            might_be_fp,
+            "Bloom filter intersection should be treated as potential conflict",
+        );
+    }
+
+    #[test]
+    fn test_concurrent_insert_different_new_struct_key() {
+        let user_field = Field::new(
+            "user",
+            DataType::Struct(
+                vec![
+                    Field::new("first", DataType::Utf8, false),
+                    Field::new("last", DataType::Utf8, false),
+                ]
+                .into(),
+            ),
+            false,
+        );
+        let schema = Arc::new(Schema::new(vec![user_field]));
+
+        // Build two batches inserting different struct keys.
+        let struct_array1 = make_struct_array_first_last_name(vec!["alice"], vec!["smith"]);
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(struct_array1)]).unwrap();
+
+        let struct_array2 = make_struct_array_first_last_name(vec!["bob"], vec!["jones"]);
+        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(struct_array2)]).unwrap();
+
+        // Build bloom filters for the struct key.
+        let field_ids = vec![0_i32];
+        let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
+        let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
+
+        let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("user")])
+            .expect("first batch should produce key");
+        let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("user")])
+            .expect("second batch should produce key");
+
+        builder1.insert(key1).unwrap();
+        builder2.insert(key2).unwrap();
+        let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
+        let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
+
+        let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
+        assert!(
+            !has_intersection,
+            "Expected bloom filters not to intersect for different struct keys",
+        );
+        assert!(
+            !might_be_fp,
+            "Bloom filter intersection should definitively indicate no conflict",
+        );
+    }
+
+    #[test]
+    fn test_concurrent_insert_same_new_struct_key() {
+        let user_field = Field::new(
+            "user",
+            DataType::Struct(
+                vec![
+                    Field::new("first", DataType::Utf8, false),
+                    Field::new("last", DataType::Utf8, false),
+                ]
+                .into(),
+            ),
+            false,
+        );
+        let schema = Arc::new(Schema::new(vec![user_field]));
+
+        // Build two batches both inserting the same struct key {first: "alice", last: "smith"}.
+        let struct_array1 = make_struct_array_first_last_name(vec!["alice"], vec!["smith"]);
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(struct_array1)]).unwrap();
+
+        let struct_array2 = make_struct_array_first_last_name(vec!["alice"], vec!["smith"]);
+        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(struct_array2)]).unwrap();
+
+        // Build bloom filters for the struct key.
+        let field_ids = vec![0_i32];
+        let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
+        let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
+
+        let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("user")])
+            .expect("first batch should produce key");
+        let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("user")])
+            .expect("second batch should produce key");
+
+        builder1.insert(key1).unwrap();
+        builder2.insert(key2).unwrap();
+        let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
+        let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
+
+        let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
+        assert!(
+            has_intersection,
+            "Expected bloom filters to intersect for identical struct keys",
+        );
+        assert!(
+            might_be_fp,
+            "Bloom filter intersection should be treated as potential conflict",
+        );
+    }
+
+    #[test]
+    fn test_concurrent_insert_same_new_nested_struct_key() {
+        // Build nested struct value {address: {city: "seattle", zip: 98101}} for the "user" column.
+        let outer_struct = make_nested_struct_array_city_zip("seattle", 98101);
+        let user_field = Field::new("user", outer_struct.data_type().clone(), false);
+        let schema = Arc::new(Schema::new(vec![user_field]));
+
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(outer_struct)]).unwrap();
+
+        let outer_struct2 = make_nested_struct_array_city_zip("seattle", 98101);
+        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(outer_struct2)]).unwrap();
+
+        // Build bloom filters for the nested struct key.
+        let field_ids = vec![0_i32];
+        let mut builder1 = KeyExistenceFilterBuilder::new(field_ids.clone());
+        let mut builder2 = KeyExistenceFilterBuilder::new(field_ids);
+
+        let key1 = extract_key_value_from_batch(&batch1, 0, &[String::from("user")])
+            .expect("first batch should produce key");
+        let key2 = extract_key_value_from_batch(&batch2, 0, &[String::from("user")])
+            .expect("second batch should produce key");
+
+        builder1.insert(key1).unwrap();
+        builder2.insert(key2).unwrap();
+        let filter1 = KeyExistenceFilter::from_bloom_filter(&builder1);
+        let filter2 = KeyExistenceFilter::from_bloom_filter(&builder2);
+
+        let (has_intersection, might_be_fp) = filter1.intersects(&filter2).unwrap();
+        assert!(
+            has_intersection,
+            "Expected bloom filters to intersect for identical nested struct keys",
+        );
+        assert!(
+            might_be_fp,
+            "Bloom filter intersection should be treated as potential conflict",
+        );
+    }
+
+    /// End-to-end test for merge_insert using a struct-typed key column.
+    #[tokio::test]
+    async fn test_merge_insert_struct_key_upsert() {
+        let user_field = Field::new(
+            "user",
+            DataType::Struct(
+                vec![
+                    Field::new("first", DataType::Utf8, false),
+                    Field::new("last", DataType::Utf8, false),
+                ]
+                .into(),
+            ),
+            false,
+        );
+        let schema = Arc::new(Schema::new(vec![
+            user_field,
+            Field::new("value", DataType::UInt32, false),
+        ]));
+
+        // Initial dataset:
+        //   (alice, smith) -> 1
+        //   (bob, jones) -> 1
+        //   (carla, doe) -> 1
+        let user_array = make_struct_array_first_last_name(
+            vec!["alice", "bob", "carla"],
+            vec!["smith", "jones", "doe"],
+        );
+        let values = UInt32Array::from(vec![1, 1, 1]);
+        let initial_batch =
+            RecordBatch::try_new(schema.clone(), vec![Arc::new(user_array), Arc::new(values)])
+                .unwrap();
+
+        let test_uri = "memory://test_merge_insert_struct_key.lance";
+        let dataset = Dataset::write(
+            RecordBatchIterator::new(vec![Ok(initial_batch)], schema.clone()),
+            test_uri,
+            None,
+        )
+        .await
+        .unwrap();
+        let dataset = Arc::new(dataset);
+
+        // New data: update alice, insert david
+        let new_user_array =
+            make_struct_array_first_last_name(vec!["alice", "david"], vec!["smith", "brown"]);
+        let new_values = UInt32Array::from(vec![10, 2]);
+        let new_batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(new_user_array), Arc::new(new_values)],
+        )
+        .unwrap();
+
+        let reader = RecordBatchIterator::new([Ok(new_batch)], schema.clone());
+        let (merged_ds, stats) = MergeInsertBuilder::try_new(dataset, vec!["user".to_string()])
+            .unwrap()
+            .when_matched(WhenMatched::UpdateAll)
+            .when_not_matched(WhenNotMatched::InsertAll)
+            .try_build()
+            .unwrap()
+            .execute(reader_to_stream(Box::new(reader)))
+            .await
+            .unwrap();
+
+        assert_eq!(stats.num_updated_rows, 1);
+        assert_eq!(stats.num_inserted_rows, 1);
+        assert_eq!(stats.num_deleted_rows, 0);
+
+        let result = merged_ds.scan().try_into_batch().await.unwrap();
+        let user_col = result
+            .column_by_name("user")
+            .unwrap()
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .unwrap();
+        let first = user_col
+            .column(0)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        let last = user_col
+            .column(1)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        let values = result
+            .column_by_name("value")
+            .unwrap()
+            .as_primitive::<UInt32Type>();
+
+        let mut rows = Vec::new();
+        for i in 0..result.num_rows() {
+            rows.push((
+                first.value(i).to_string(),
+                last.value(i).to_string(),
+                values.value(i),
+            ));
+        }
+        rows.sort();
+
+        assert_eq!(
+            rows,
+            vec![
+                ("alice".to_string(), "smith".to_string(), 10),
+                ("bob".to_string(), "jones".to_string(), 1),
+                ("carla".to_string(), "doe".to_string(), 1),
+                ("david".to_string(), "brown".to_string(), 2),
+            ],
+        );
+    }
+
+    fn make_struct_array_first_last_name(first: Vec<&str>, last: Vec<&str>) -> StructArray {
+        let first = StringArray::from(first);
+        let last = StringArray::from(last);
+
+        StructArray::from(vec![
+            (
+                Arc::new(Field::new("first", DataType::Utf8, false)),
+                Arc::new(first) as Arc<dyn Array>,
+            ),
+            (
+                Arc::new(Field::new("last", DataType::Utf8, false)),
+                Arc::new(last) as Arc<dyn Array>,
+            ),
+        ])
+    }
+
+    fn make_nested_struct_array_city_zip(city: &str, zip: i32) -> StructArray {
+        let city = StringArray::from(vec![city]);
+        let zip = Int32Array::from(vec![zip]);
+
+        let inner_struct = StructArray::from(vec![
+            (
+                Arc::new(Field::new("city", DataType::Utf8, false)),
+                Arc::new(city) as Arc<dyn Array>,
+            ),
+            (
+                Arc::new(Field::new("zip", DataType::Int32, false)),
+                Arc::new(zip) as Arc<dyn Array>,
+            ),
+        ]);
+
+        StructArray::from(vec![(
+            Arc::new(Field::new(
+                "address",
+                inner_struct.data_type().clone(),
+                false,
+            )),
+            Arc::new(inner_struct) as Arc<dyn Array>,
+        )])
+    }
+
+    fn make_nested_array(inner_lists: &[&[&str]]) -> ListArray {
+        let mut inner_builder = ListBuilder::new(StringBuilder::new());
+        for inner in inner_lists {
+            inner_builder.append_value(inner.iter().map(|s| Some(*s)));
+        }
+        let inner_list_array = inner_builder.finish();
+
+        let offsets = ScalarBuffer::<i32>::from(vec![0, inner_list_array.len() as i32]);
+        let offsets = OffsetBuffer::new(offsets);
+        ListArray::new(
+            Arc::new(Field::new(
+                "item",
+                inner_list_array.data_type().clone(),
+                inner_list_array.nulls().is_some(),
+            )),
+            offsets,
+            Arc::new(inner_list_array),
+            None,
+        )
+    }
+
     /// Test that merge_insert with bloom filter fails when committing against
     /// an Update transaction that doesn't have a filter. We can't determine if
     /// the Update operation conflicted with our inserted rows.
diff --git a/rust/lance/src/dataset/write/merge_insert/inserted_rows.rs b/rust/lance/src/dataset/write/merge_insert/inserted_rows.rs
index 8b9073fefc3..f4ccfa1195e 100644
--- a/rust/lance/src/dataset/write/merge_insert/inserted_rows.rs
+++ b/rust/lance/src/dataset/write/merge_insert/inserted_rows.rs
@@ -8,7 +8,10 @@ use std::collections::HashSet;
 use std::hash::{Hash, Hasher};
 
 use arrow_array::cast::AsArray;
-use arrow_array::{BinaryArray, LargeBinaryArray, LargeStringArray, RecordBatch, StringArray};
+use arrow_array::{
+    Array, BinaryArray, LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, RecordBatch,
+    StringArray, StructArray,
+};
 use arrow_schema::DataType;
 use deepsize::DeepSizeOf;
 use lance_core::Result;
@@ -27,6 +30,8 @@ pub enum KeyValue {
     Int64(i64),
     UInt64(u64),
     Binary(Vec<u8>),
+    List(Vec<KeyValue>),
+    Struct(Vec<KeyValue>),
     Composite(Vec<KeyValue>),
 }
 
@@ -37,7 +42,7 @@ impl KeyValue {
             Self::Int64(i) => i.to_le_bytes().to_vec(),
             Self::UInt64(u) => u.to_le_bytes().to_vec(),
             Self::Binary(b) => b.clone(),
-            Self::Composite(values) => {
+            Self::List(values) | Self::Struct(values) | Self::Composite(values) => {
                 let mut result = Vec::new();
                 for value in values {
                     result.extend_from_slice(&value.to_bytes());
                 }
@@ -289,49 +294,424 @@ pub fn extract_key_value_from_batch(
             return None;
         }
 
-        let key_part = match column.data_type() {
-            DataType::Utf8 => {
-                let arr = column.as_any().downcast_ref::<StringArray>()?;
-                KeyValue::String(arr.value(row_idx).to_string())
+        let key_part = extract_key_value(column, row_idx)?;
+        parts.push(key_part);
+    }
+
+    if parts.is_empty() {
+        None
+    } else if parts.len() == 1 {
+        Some(parts.into_iter().next().unwrap())
+    } else {
+        Some(KeyValue::Composite(parts))
+    }
+}
+
+fn extract_key_value(array: &dyn Array, row_idx: usize) -> Option<KeyValue> {
+    let v = match array.data_type() {
+        DataType::Utf8 => {
+            let arr = array.as_any().downcast_ref::<StringArray>()?;
+            KeyValue::String(arr.value(row_idx).to_string())
+        }
+        DataType::LargeUtf8 => {
+            let arr = array.as_any().downcast_ref::<LargeStringArray>()?;
+            KeyValue::String(arr.value(row_idx).to_string())
+        }
+        DataType::UInt64 => {
+            let arr = array.as_primitive::<UInt64Type>();
+            KeyValue::UInt64(arr.value(row_idx))
+        }
+        DataType::Int64 => {
+            let arr = array.as_primitive::<Int64Type>();
+            KeyValue::Int64(arr.value(row_idx))
+        }
+        DataType::UInt32 => {
+            let arr = array.as_primitive::<UInt32Type>();
+            KeyValue::UInt64(arr.value(row_idx) as u64)
+        }
+        DataType::Int32 => {
+            let arr = array.as_primitive::<Int32Type>();
+            KeyValue::Int64(arr.value(row_idx) as i64)
+        }
+        DataType::Binary => {
+            let arr = array.as_any().downcast_ref::<BinaryArray>()?;
+            KeyValue::Binary(arr.value(row_idx).to_vec())
+        }
+        DataType::LargeBinary => {
+            let arr = array.as_any().downcast_ref::<LargeBinaryArray>()?;
+            KeyValue::Binary(arr.value(row_idx).to_vec())
+        }
+        DataType::List(_) => {
+            let list_array = array.as_any().downcast_ref::<ListArray>().unwrap();
+            let values = list_array.value(row_idx);
+
+            let mut elements = Vec::with_capacity(values.len());
+            for i in 0..values.len() {
+                if values.is_null(i) {
+                    return None;
+                }
+                let element = extract_key_value(&values, i)?;
+                elements.push(element);
             }
-            DataType::LargeUtf8 => {
-                let arr = column.as_any().downcast_ref::<LargeStringArray>()?;
-                KeyValue::String(arr.value(row_idx).to_string())
+            KeyValue::List(elements)
+        }
+        DataType::LargeList(_) => {
+            let list_array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
+            let values = list_array.value(row_idx);
+
+            let mut elements = Vec::with_capacity(values.len());
+            for i in 0..values.len() {
+                if values.is_null(i) {
+                    return None;
+                }
+                let element = extract_key_value(&values, i)?;
+                elements.push(element);
+            }
+            KeyValue::List(elements)
+        }
+        DataType::Struct(_) => {
+            let struct_array = array.as_any().downcast_ref::<StructArray>()?;
+            let mut elements = Vec::with_capacity(struct_array.num_columns());
+            for i in 0..struct_array.num_columns() {
+                let child = struct_array.column(i);
+                if child.is_null(row_idx) {
+                    return None;
+                }
+                let field_value = extract_key_value(child.as_ref(), row_idx)?;
+                elements.push(field_value);
+            }
+            KeyValue::Struct(elements)
+        }
+        _ => return None,
+    };
+    Some(v)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::Arc;
+
+    use arrow_array::builder::{Int32Builder, ListBuilder, StringBuilder};
+    use arrow_array::{Int32Array, RecordBatch, StringArray, StructArray};
+    use arrow_schema::{Field, Schema};
+
+    #[test]
+    fn test_extract_key_value_from_batch_list_int() {
+        let values_builder = Int32Builder::new();
+        let mut list_builder = ListBuilder::new(values_builder);
+
+        list_builder.append_value([Some(1), Some(2)]);
+        list_builder.append_value([Some(3), Some(4), Some(5)]);
+
+        let list_array = list_builder.finish();
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "id",
+            list_array.data_type().clone(),
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(list_array)])
+            .expect("batch should be valid");
+
+        let key0 = extract_key_value_from_batch(&batch, 0, &[String::from("id")])
+            .expect("first row should produce a key");
+        let key1 = extract_key_value_from_batch(&batch, 1, &[String::from("id")])
+            .expect("second row should produce a key");
+
+        match &key0 {
+            KeyValue::List(values) => {
+                assert_eq!(values.len(), 2);
+                assert_eq!(values[0], KeyValue::Int64(1));
+                assert_eq!(values[1], KeyValue::Int64(2));
+            }
+            other => panic!("expected list key, got {:?}", other),
+        }
+
+        match &key1 {
+            KeyValue::List(values) => {
+                assert_eq!(values.len(), 3);
+                assert_eq!(values[0], KeyValue::Int64(3));
+                assert_eq!(values[1], KeyValue::Int64(4));
+                assert_eq!(values[2], KeyValue::Int64(5));
+            }
+            other => panic!("expected list key, got {:?}", other),
+        }
+
+        assert_ne!(
+            key0.hash_value(),
+            key1.hash_value(),
+            "different list values should hash differently",
+        );
+    }
+
+    #[test]
+    fn test_extract_key_value_from_batch_empty_list() {
+        let values_builder = Int32Builder::new();
+        let mut list_builder = ListBuilder::new(values_builder);
+
+        list_builder.append_value(std::iter::empty::<Option<i32>>());
+
+        let list_array = list_builder.finish();
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "id",
+            list_array.data_type().clone(),
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(list_array)])
+            .expect("batch should be valid");
+
+        let key = extract_key_value_from_batch(&batch, 0, &[String::from("id")])
+            .expect("empty list should still produce a key");
+
+        match key {
+            KeyValue::List(values) => {
+                assert!(values.is_empty(), "expected empty list");
+            }
+            other => panic!("expected list key, got {:?}", other),
+        }
+    }
+
+    #[test]
+    fn test_extract_key_value_from_batch_list_utf8() {
+        let values_builder = StringBuilder::new();
+        let mut list_builder = ListBuilder::new(values_builder);
+
+        list_builder.append_value([Some("a"), Some("bc")]);
+        list_builder.append_value([Some("de")]);
+
+        let list_array = list_builder.finish();
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "id",
+            list_array.data_type().clone(),
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(list_array)])
+            .expect("batch should be valid");
+
+        let key0 = extract_key_value_from_batch(&batch, 0, &[String::from("id")])
+            .expect("first row should produce a key");
+        let key1 = extract_key_value_from_batch(&batch, 1, &[String::from("id")])
+            .expect("second row should produce a key");
+
+        match &key0 {
+            KeyValue::List(values) => {
+                assert_eq!(values.len(), 2);
+                assert_eq!(values[0], KeyValue::String("a".to_string()));
+                assert_eq!(values[1], KeyValue::String("bc".to_string()));
             }
-            DataType::UInt64 => {
-                let arr = column.as_primitive::<UInt64Type>();
-                KeyValue::UInt64(arr.value(row_idx))
+            other => panic!("expected list key, got {:?}", other),
+        }
+
+        match &key1 {
+            KeyValue::List(values) => {
+                assert_eq!(values.len(), 1);
+                assert_eq!(values[0], KeyValue::String("de".to_string()));
             }
-            DataType::Int64 => {
-                let arr = column.as_primitive::<Int64Type>();
-                KeyValue::Int64(arr.value(row_idx))
+            other => panic!("expected list key, got {:?}", other),
+        }
+
+        assert_ne!(
+            key0.hash_value(),
+            key1.hash_value(),
+            "different list values should hash differently",
+        );
+    }
+
+    #[test]
+    fn test_extract_key_value_from_batch_list_with_null_child() {
+        let values_builder = Int32Builder::new();
+        let mut list_builder = ListBuilder::new(values_builder);
+
+        list_builder.append_value([Some(1), Some(2)]);
+        list_builder.append_value([Some(3), None]);
+
+        let list_array = list_builder.finish();
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "id",
+            list_array.data_type().clone(),
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(list_array)])
+            .expect("batch should be valid");
+
+        let key0 = extract_key_value_from_batch(&batch, 0, &[String::from("id")])
+            .expect("first row should produce a key");
+        let key1 = extract_key_value_from_batch(&batch, 1, &[String::from("id")]);
+
+        match &key0 {
+            KeyValue::List(values) => {
+                assert_eq!(values.len(), 2);
+                assert_eq!(values[0], KeyValue::Int64(1));
+                assert_eq!(values[1], KeyValue::Int64(2));
             }
-            DataType::UInt32 => {
-                let arr = column.as_primitive::<UInt32Type>();
-                KeyValue::UInt64(arr.value(row_idx) as u64)
+            other => panic!("expected list key, got {:?}", other),
+        }
+
+        assert!(
+            key1.is_none(),
+            "list row with a null child should not produce a key",
+        );
+    }
+
+    #[test]
+    fn test_extract_key_value_from_batch_struct_int() {
+        let a_values = Int32Array::from(vec![1, 3]);
+        let b_values = Int32Array::from(vec![2, 4]);
+
+        let struct_array = StructArray::from(vec![
+            (
+                Arc::new(Field::new("a", arrow_schema::DataType::Int32, false)),
+                Arc::new(a_values) as Arc<dyn Array>,
+            ),
+            (
+                Arc::new(Field::new("b", arrow_schema::DataType::Int32, false)),
+                Arc::new(b_values) as Arc<dyn Array>,
+            ),
+        ]);
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "id",
+            struct_array.data_type().clone(),
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(struct_array)])
+            .expect("batch should be valid");
+
+        let key0 = extract_key_value_from_batch(&batch, 0, &[String::from("id")])
+            .expect("first row should produce a key");
+        let key1 = extract_key_value_from_batch(&batch, 1, &[String::from("id")])
+            .expect("second row should produce a key");
+
+        match &key0 {
+            KeyValue::Struct(values) => {
+                assert_eq!(values.len(), 2);
+                assert_eq!(values[0], KeyValue::Int64(1));
+                assert_eq!(values[1], KeyValue::Int64(2));
             }
-            DataType::Int32 => {
-                let arr = column.as_primitive::<Int32Type>();
-                KeyValue::Int64(arr.value(row_idx) as i64)
+            other => panic!("expected struct key, got {:?}", other),
+        }
+
+        match &key1 {
+            KeyValue::Struct(values) => {
+                assert_eq!(values.len(), 2);
+                assert_eq!(values[0], KeyValue::Int64(3));
+                assert_eq!(values[1], KeyValue::Int64(4));
             }
-            DataType::Binary => {
-                let arr = column.as_any().downcast_ref::<BinaryArray>()?;
-                KeyValue::Binary(arr.value(row_idx).to_vec())
+            other => panic!("expected struct key, got {:?}", other),
+        }
+
+        assert_ne!(
+            key0.hash_value(),
+            key1.hash_value(),
+            "different struct values should hash differently",
+        );
+    }
+
+    #[test]
+    fn test_extract_key_value_from_batch_struct_utf8() {
+        let first_names = StringArray::from(vec!["alice", "bob"]);
+        let last_names = StringArray::from(vec!["smith", "jones"]);
+
+        let struct_array = StructArray::from(vec![
+            (
+                Arc::new(Field::new("first", arrow_schema::DataType::Utf8, false)),
+                Arc::new(first_names) as Arc<dyn Array>,
+            ),
+            (
+                Arc::new(Field::new("last", arrow_schema::DataType::Utf8, false)),
+                Arc::new(last_names) as Arc<dyn Array>,
+            ),
+        ]);
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "id",
+            struct_array.data_type().clone(),
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(struct_array)])
+            .expect("batch should be valid");
+
+        let key0 = extract_key_value_from_batch(&batch, 0, &[String::from("id")])
+            .expect("first row should produce a key");
+        let key1 = extract_key_value_from_batch(&batch, 1, &[String::from("id")])
+            .expect("second row should produce a key");
+
+        match &key0 {
+            KeyValue::Struct(values) => {
+                assert_eq!(values.len(), 2);
+                assert_eq!(values[0], KeyValue::String("alice".to_string()));
+                assert_eq!(values[1], KeyValue::String("smith".to_string()));
             }
-            DataType::LargeBinary => {
-                let arr = column.as_any().downcast_ref::<LargeBinaryArray>()?;
-                KeyValue::Binary(arr.value(row_idx).to_vec())
+            other => panic!("expected struct key, got {:?}", other),
+        }
+
+        match &key1 {
+            KeyValue::Struct(values) => {
+                assert_eq!(values.len(), 2);
+                assert_eq!(values[0], KeyValue::String("bob".to_string()));
+                assert_eq!(values[1], KeyValue::String("jones".to_string()));
             }
-            _ => return None,
-        };
-        parts.push(key_part);
+            other => panic!("expected struct key, got {:?}", other),
+        }
+
+        assert_ne!(
+            key0.hash_value(),
+            key1.hash_value(),
+            "different struct values should hash differently",
+        );
     }
 
-    if parts.is_empty() {
-        None
-    } else if parts.len() == 1 {
-        Some(parts.into_iter().next().unwrap())
-    } else {
-        Some(KeyValue::Composite(parts))
+    #[test]
+    fn test_extract_key_value_from_batch_struct_with_null_child() {
+        let a_values = Int32Array::from(vec![Some(1), None]);
+        let b_values = Int32Array::from(vec![Some(2), Some(3)]);
+
+        let struct_array = StructArray::from(vec![
+            (
+                Arc::new(Field::new("a", arrow_schema::DataType::Int32, true)),
+                Arc::new(a_values) as Arc<dyn Array>,
+            ),
+            (
+                Arc::new(Field::new("b", arrow_schema::DataType::Int32, true)),
+                Arc::new(b_values) as Arc<dyn Array>,
+            ),
+        ]);
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "id",
+            struct_array.data_type().clone(),
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(schema, vec![Arc::new(struct_array)])
+            .expect("batch should be valid");
+
+        let key0 = extract_key_value_from_batch(&batch, 0, &[String::from("id")])
+            .expect("first row should produce a key");
+        let key1 = extract_key_value_from_batch(&batch, 1, &[String::from("id")]);
+
+        match &key0 {
+            KeyValue::Struct(values) => {
+                assert_eq!(values.len(), 2);
+                assert_eq!(values[0], KeyValue::Int64(1));
+                assert_eq!(values[1], KeyValue::Int64(2));
+            }
+            other => panic!("expected struct key, got {:?}", other),
+        }
+
+        assert!(
+            key1.is_none(),
+            "struct row with a null child should not produce a key",
+        );
     }
 }