diff --git a/rust/lance-core/src/datatypes/schema.rs b/rust/lance-core/src/datatypes/schema.rs
index 242dea3315b..cf225fe2dd2 100644
--- a/rust/lance-core/src/datatypes/schema.rs
+++ b/rust/lance-core/src/datatypes/schema.rs
@@ -412,7 +412,7 @@ impl Schema {
         let mut fields = vec![];
         for field in self.fields.iter() {
             if let Some(other_field) = other.field(&field.name) {
-                if field.data_type().is_struct() {
+                if field.data_type().is_nested() {
                     if let Some(f) = field.exclude(other_field) {
                         fields.push(f)
                     }
diff --git a/rust/lance-encoding/src/version.rs b/rust/lance-encoding/src/version.rs
index 96fd7b6d7fe..726f36ec3cb 100644
--- a/rust/lance-encoding/src/version.rs
+++ b/rust/lance-encoding/src/version.rs
@@ -3,6 +3,8 @@
 
 use std::str::FromStr;
 
+use lance_arrow::DataTypeExt;
+use lance_core::datatypes::Field;
 use lance_core::{Error, Result};
 use snafu::location;
 
@@ -85,6 +87,19 @@ impl LanceFileVersion {
     pub fn support_add_sub_column(&self) -> bool {
         self > &Self::V2_1
     }
+
+    /// Returns whether sub-columns of `field` may be dropped individually for this
+    /// file version.
+    ///
+    /// Versions ordered at or below `V2_1` only allow it when `field` is a struct;
+    /// later versions allow it for any type for which `is_nested()` is true.
+    pub fn support_remove_sub_column(&self, field: &Field) -> bool {
+        if self <= &Self::V2_1 {
+            field.data_type().is_struct()
+        } else {
+            field.data_type().is_nested()
+        }
+    }
 }
 
 impl std::fmt::Display for LanceFileVersion {
diff --git a/rust/lance/src/dataset/schema_evolution.rs b/rust/lance/src/dataset/schema_evolution.rs
index c4df5bd8f4b..da5ea5b1688 100644
--- a/rust/lance/src/dataset/schema_evolution.rs
+++ b/rust/lance/src/dataset/schema_evolution.rs
@@ -702,8 +702,9 @@ pub(super) async fn drop_columns(dataset: &mut Dataset, columns: &[&str]) -> Res
         }
     }
 
+    let version = dataset.manifest.data_storage_format.lance_file_version()?;
     let columns_to_remove = dataset.manifest.schema.project(columns)?;
-    let new_schema = dataset.manifest.schema.exclude(columns_to_remove)?;
+    let new_schema = exclude(&dataset.manifest.schema, &columns_to_remove, &version)?;
 
     if new_schema.fields.is_empty() {
         return Err(Error::invalid_input(
@@ -725,15 +726,47 @@ pub(super) async fn drop_columns(dataset: &mut Dataset, columns: &[&str]) -> Res
     Ok(())
 }
 
+/// Builds a new [`Schema`] from `source` with the fields listed in `other` removed.
+///
+/// Fields of `source` that have no same-named field in `other` are copied unchanged.
+/// A field that does appear in `other` is handled according to `version`:
+///
+/// * when `version.support_remove_sub_column(field)` is true, the field is replaced by
+///   the result of `Field::exclude`, and omitted when that result is `None`;
+/// * otherwise the whole field is omitted.
+///
+/// The returned schema reuses the metadata of `source`.
+pub fn exclude(source: &Schema, other: &Schema, version: &LanceFileVersion) -> Result<Schema> {
+    let mut fields = Vec::with_capacity(source.fields.len());
+    for field in source.fields.iter() {
+        match other.field(&field.name) {
+            // Not listed for removal: keep the field as is.
+            None => fields.push(field.clone()),
+            // Sub-columns can be removed one by one: keep whatever `Field::exclude` yields.
+            Some(other_field) if version.support_remove_sub_column(field) => {
+                fields.extend(field.exclude(other_field));
+            }
+            // Sub-column removal is not supported for this field: drop it entirely.
+            Some(_) => {}
+        }
+    }
+    Ok(Schema {
+        fields,
+        metadata: source.metadata.clone(),
+    })
+}
+
 #[cfg(test)]
 mod test {
     use std::collections::HashMap;
     use std::sync::Mutex;
 
     use crate::dataset::WriteParams;
+    use arrow_array::{
+        ArrayRef, Int32Array, ListArray, RecordBatchIterator, StringArray, StructArray,
+    };
 
     use super::*;
-    use arrow_array::{Int32Array, RecordBatchIterator};
     use arrow_schema::Fields as ArrowFields;
     use lance_core::utils::tempfile::TempStrDir;
     use lance_file::version::LanceFileVersion;
@@ -1230,6 +1263,195 @@ mod test {
         Ok(())
     }
 
+    async fn prepare_dataset(version: LanceFileVersion) -> Result<Dataset> {
+        // id: int32
+        // people: list<struct<name: utf8, age: int32, city: utf8>>
+        let person_struct_type = DataType::Struct(ArrowFields::from(vec![
+            ArrowField::new("name", DataType::Utf8, false),
+            ArrowField::new("age", DataType::Int32, false),
+            ArrowField::new("city", DataType::Utf8, false),
+        ]));
+
+        let list_of_struct_type = DataType::List(Arc::new(ArrowField::new(
+            "item",
+            person_struct_type.clone(),
+            false,
+        )));
+
+        let schema = Arc::new(ArrowSchema::new_with_metadata(
+            vec![
+                ArrowField::new("id", DataType::Int32, false),
+                ArrowField::new("people", list_of_struct_type.clone(), false),
+            ],
+            HashMap::<String, String>::new(),
+        ));
+
+        // Data: 3 rows, people is a list of 2, 3, 1 structs
+        let all_names = StringArray::from(vec!["Alice", "Bob", "Charlie", "David", "Eve", "Frank"]);
+        let all_ages = Int32Array::from(vec![25, 30, 35, 28, 32, 40]);
+        let all_cities = StringArray::from(vec![
+            "Beijing",
+            "Shanghai",
+            "Guangzhou",
+            "Shenzhen",
+            "Hangzhou",
+            "Chengdu",
+        ]);
+        let all_struct = StructArray::new(
+            ArrowFields::from(vec![
+                ArrowField::new("name", DataType::Utf8, false),
+                ArrowField::new("age", DataType::Int32, false),
+                ArrowField::new("city", DataType::Utf8, false),
+            ]),
+            vec![
+                Arc::new(all_names) as ArrayRef,
+                Arc::new(all_ages) as ArrayRef,
+                Arc::new(all_cities) as ArrayRef,
+            ],
+            None,
+        );
+
+        let all_people = ListArray::new(
+            Arc::new(ArrowField::new("item", person_struct_type, false)),
+            arrow_buffer::OffsetBuffer::new(arrow_buffer::ScalarBuffer::from(vec![
+                0i32, 2i32, 5i32, 6i32,
+            ])),
+            Arc::new(all_struct),
+            None,
+        );
+
+        let ids = Int32Array::from(vec![1, 2, 3]);
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(ids) as ArrayRef, Arc::new(all_people) as ArrayRef],
+        )?;
+
+        let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
+        let dataset = Dataset::write(
+            reader,
+            "memory://test",
+            Some(WriteParams {
+                data_storage_version: Some(version),
+                ..Default::default()
+            }),
+        )
+        .await?;
+
+        // Verify schema
+        assert_eq!(dataset.schema().fields.len(), 2);
+        assert_eq!(dataset.schema().fields[0].name, "id");
+        assert_eq!(dataset.schema().fields[1].name, "people");
+
+        Ok(dataset)
+    }
+
+    #[rstest]
+    #[tokio::test]
+    async fn test_drop_list_struct_sub_columns_legacy(
+        #[values(
+            LanceFileVersion::Legacy,
+            LanceFileVersion::V2_0,
+            LanceFileVersion::V2_1
+        )]
+        version: LanceFileVersion,
+    ) -> Result<()> {
+        let mut dataset = prepare_dataset(version).await?;
+
+        // drop sub-column city from list(struct)
+        dataset.drop_columns(&["people.item.city"]).await?;
+        dataset.validate().await?;
+
+        // people column has been fully removed
+        assert_eq!(dataset.schema().fields.len(), 1);
+        assert_eq!(dataset.schema().fields[0].name, "id");
+
+        Ok(())
+    }
+
+    #[rstest]
+    #[tokio::test]
+    async fn test_drop_list_struct_sub_columns(
+        #[values(LanceFileVersion::V2_2)] version: LanceFileVersion,
+    ) -> Result<()> {
+        let mut dataset = prepare_dataset(version).await?;
+
+        // drop sub-column city from list(struct)
+        dataset.drop_columns(&["people.item.city"]).await?;
+        dataset.validate().await?;
+
+        // people.item only contains name, age
+        let expected_schema = ArrowSchema::new_with_metadata(
+            vec![
+                ArrowField::new("id", DataType::Int32, false),
+                ArrowField::new(
+                    "people",
+                    DataType::List(Arc::new(ArrowField::new(
+                        "item",
+                        DataType::Struct(ArrowFields::from(vec![
+                            ArrowField::new("name", DataType::Utf8, false),
+                            ArrowField::new("age", DataType::Int32, false),
+                        ])),
+                        false,
+                    ))),
+                    false,
+                ),
+            ],
+            HashMap::<String, String>::new(),
+        );
+        assert_eq!(ArrowSchema::from(dataset.schema()), expected_schema);
+
+        // Verify data
+        let batch = dataset.scan().try_into_batch().await?;
+        assert_eq!(batch.num_rows(), 3);
+        assert_eq!(batch.num_columns(), 2);
+
+        let list_array = batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<ListArray>()
+            .unwrap();
+        let list_value = list_array.value(0);
+        let struct_array = list_value.as_any().downcast_ref::<StructArray>().unwrap();
+        assert!(struct_array.column_by_name("city").is_none());
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_exclude_fields() {
+        let arrow_schema = ArrowSchema::new(vec![
+            ArrowField::new("a", DataType::Int32, false),
+            ArrowField::new(
+                "b",
+                DataType::Struct(ArrowFields::from(vec![
+                    ArrowField::new("f1", DataType::Utf8, true),
+                    ArrowField::new("f2", DataType::Boolean, false),
+                    ArrowField::new("f3", DataType::Float32, false),
+                ])),
+                true,
+            ),
+            ArrowField::new("c", DataType::Float64, false),
+        ]);
+        let schema = Schema::try_from(&arrow_schema).unwrap();
+
+        let projection = schema.project(&["a", "b.f2", "b.f3"]).unwrap();
+        let excluded = exclude(&schema, &projection, &LanceFileVersion::V2_2).unwrap();
+
+        let expected_arrow_schema = ArrowSchema::new(vec![
+            ArrowField::new(
+                "b",
+                DataType::Struct(ArrowFields::from(vec![ArrowField::new(
+                    "f1",
+                    DataType::Utf8,
+                    true,
+                )])),
+                true,
+            ),
+            ArrowField::new("c", DataType::Float64, false),
+        ]);
+        assert_eq!(ArrowSchema::from(&excluded), expected_arrow_schema);
+    }
+
     #[rstest]
     #[tokio::test]
     async fn test_rename_columns(