diff --git a/rust/lance-arrow/src/lib.rs b/rust/lance-arrow/src/lib.rs
index 29314b166e1..43d8fd2a5cd 100644
--- a/rust/lance-arrow/src/lib.rs
+++ b/rust/lance-arrow/src/lib.rs
@@ -27,7 +27,9 @@ pub mod schema;
 pub use schema::*;
 pub mod bfloat16;
 pub mod floats;
+use crate::list::ListArrayExt;
 pub use floats::*;
+
 pub mod cast;
 pub mod json;
 pub mod list;
@@ -1308,8 +1310,8 @@ fn merge_with_schema(
                             .unwrap();
                         let merged_values = merge_list_child_values(
                             child_field.as_ref(),
-                            left_list.values().clone(),
-                            right_list.values().clone(),
+                            left_list.trimmed_values(),
+                            right_list.trimmed_values(),
                         );
                         let merged_validity =
                             merge_struct_validity(left_list.nulls(), right_list.nulls());
@@ -1333,8 +1335,8 @@ fn merge_with_schema(
                             .unwrap();
                         let merged_values = merge_list_child_values(
                             child_field.as_ref(),
-                            left_list.values().clone(),
-                            right_list.values().clone(),
+                            left_list.trimmed_values(),
+                            right_list.trimmed_values(),
                         );
                         let merged_validity =
                             merge_struct_validity(left_list.nulls(), right_list.nulls());
@@ -2050,4 +2052,104 @@ mod tests {
         assert!(count.is_null(0));
         assert!(count.is_null(1));
     }
+
+    #[test]
+    fn test_merge_struct_lists() {
+        test_merge_struct_lists_generic::<i32>();
+    }
+
+    #[test]
+    fn test_merge_struct_large_lists() {
+        test_merge_struct_lists_generic::<i64>();
+    }
+
+    /// Merges two `List<Struct>` (or `LargeList<Struct>`) columns and checks the result.
+    ///
+    /// The left child array holds 20 rows while its offsets only address the first 4,
+    /// so the merge has to work on the trimmed values for the struct children to line
+    /// up with the 4-row right child.
+    fn test_merge_struct_lists_generic<O: OffsetSizeTrait>() {
+        // Left side: 20 child rows, but offsets from lengths [3, 1] only address 4.
+        let left_fields = Fields::from(vec![
+            Field::new("company_id", DataType::Int32, true),
+            Field::new("count", DataType::Int32, true),
+        ]);
+        let left_company_id = Arc::new(Int32Array::from_iter_values(1..=20));
+        let left_count = Arc::new(Int32Array::from_iter_values((1..=20).map(|id: i32| id * 10)));
+        let left_struct = Arc::new(StructArray::new(
+            left_fields.clone(),
+            vec![left_company_id, left_count],
+            None,
+        ));
+        let left_list = Arc::new(GenericListArray::<O>::new(
+            Arc::new(Field::new("item", DataType::Struct(left_fields), true)),
+            OffsetBuffer::from_lengths([3, 1]),
+            left_struct,
+            None,
+        ));
+        let left_list_struct = Arc::new(StructArray::new(
+            Fields::from(vec![Field::new("companies", left_list.data_type().clone(), true)]),
+            vec![left_list as ArrayRef],
+            None,
+        ));
+
+        // Right side: a 4-row child that is fully addressed by the same offsets.
+        let right_fields = Fields::from(vec![Field::new("company_name", DataType::Utf8, true)]);
+        let right_company_name = Arc::new(StringArray::from(vec![
+            "Google",
+            "Microsoft",
+            "Apple",
+            "Facebook",
+        ]));
+        let right_struct = Arc::new(StructArray::new(
+            right_fields.clone(),
+            vec![right_company_name],
+            None,
+        ));
+        let right_list = Arc::new(GenericListArray::<O>::new(
+            Arc::new(Field::new("item", DataType::Struct(right_fields), true)),
+            OffsetBuffer::from_lengths([3, 1]),
+            right_struct,
+            None,
+        ));
+        let right_list_struct = Arc::new(StructArray::new(
+            Fields::from(vec![Field::new("companies", right_list.data_type().clone(), true)]),
+            vec![right_list as ArrayRef],
+            None,
+        ));
+
+        // Target schema: one list column whose struct items carry both sides' columns.
+        let target_item = Arc::new(Field::new(
+            "item",
+            DataType::Struct(Fields::from(vec![
+                Field::new("company_id", DataType::Int32, true),
+                Field::new("company_name", DataType::Utf8, true),
+                Field::new("count", DataType::Int32, true),
+            ])),
+            true,
+        ));
+        let target_type = if O::IS_LARGE {
+            DataType::LargeList(target_item)
+        } else {
+            DataType::List(target_item)
+        };
+        let target_fields = Fields::from(vec![Field::new("companies", target_type, true)]);
+
+        let merged_array = merge_with_schema(&left_list_struct, &right_list_struct, &target_fields);
+        assert_eq!(merged_array.len(), 2);
+
+        // The merged child must hold exactly the 4 addressed rows and all 3 target columns.
+        let merged_list = merged_array
+            .column(0)
+            .as_any()
+            .downcast_ref::<GenericListArray<O>>()
+            .unwrap();
+        assert_eq!(merged_list.values().len(), 4);
+        let merged_items = merged_list
+            .values()
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .unwrap();
+        assert_eq!(merged_items.num_columns(), 3);
+    }
 }