Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/lance-core/src/datatypes/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ impl Schema {
let mut fields = vec![];
for field in self.fields.iter() {
if let Some(other_field) = other.field(&field.name) {
if field.data_type().is_struct() {
if field.data_type().is_nested() {
if let Some(f) = field.exclude(other_field) {
fields.push(f)
}
Expand Down
10 changes: 10 additions & 0 deletions rust/lance-encoding/src/version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

use std::str::FromStr;

use lance_arrow::DataTypeExt;
use lance_core::datatypes::Field;
use lance_core::{Error, Result};
use snafu::location;

Expand Down Expand Up @@ -85,6 +87,14 @@ impl LanceFileVersion {
/// Whether this file version supports adding a sub-column to an
/// existing nested column. Only versions strictly newer than 2.1 do.
pub fn support_add_sub_column(&self) -> bool {
    // `a > b` is equivalent to `b < a` for a derived PartialOrd.
    &Self::V2_1 < self
}

/// Whether this file version supports dropping a sub-column of `field`.
///
/// Versions up to and including 2.1 only support partial removal for plain
/// struct columns; newer versions extend this to any nested data type.
pub fn support_remove_sub_column(&self, field: &Field) -> bool {
    let data_type = field.data_type();
    match self <= &Self::V2_1 {
        true => data_type.is_struct(),
        false => data_type.is_nested(),
    }
}
}

impl std::fmt::Display for LanceFileVersion {
Expand Down
220 changes: 218 additions & 2 deletions rust/lance/src/dataset/schema_evolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,8 +702,9 @@ pub(super) async fn drop_columns(dataset: &mut Dataset, columns: &[&str]) -> Res
}
}

let version = dataset.manifest.data_storage_format.lance_file_version()?;
let columns_to_remove = dataset.manifest.schema.project(columns)?;
let new_schema = dataset.manifest.schema.exclude(columns_to_remove)?;
let new_schema = exclude(&dataset.manifest.schema, &columns_to_remove, &version)?;
Comment thread
wojiaodoubao marked this conversation as resolved.

if new_schema.fields.is_empty() {
return Err(Error::invalid_input(
Expand All @@ -725,15 +726,41 @@ pub(super) async fn drop_columns(dataset: &mut Dataset, columns: &[&str]) -> Res
Ok(())
}

/// Returns a new [`Schema`] equal to `source` with the fields of `other` removed.
///
/// For each top-level field of `source` that also appears in `other`:
/// - if the file `version` supports removing sub-columns of that field's type,
///   only the matching sub-fields are excluded and the remainder is kept;
/// - otherwise the whole field is dropped.
/// Fields absent from `other` are kept unchanged.
pub fn exclude(source: &Schema, other: &Schema, version: &LanceFileVersion) -> Result<Schema> {
    let other: Schema = other.try_into().map_err(|_| Error::Schema {
        message: "The other schema is not compatible with this schema".to_string(),
        location: location!(),
    })?;
    let mut kept = Vec::with_capacity(source.fields.len());
    for field in &source.fields {
        match other.field(&field.name) {
            // Not targeted for exclusion: keep as-is.
            None => kept.push(field.clone()),
            // Partial removal supported: keep whatever sub-fields remain.
            Some(excluded) if version.support_remove_sub_column(field) => {
                if let Some(remaining) = field.exclude(excluded) {
                    kept.push(remaining);
                }
            }
            // Partial removal unsupported for this version: drop the whole field.
            Some(_) => {}
        }
    }
    Ok(Schema {
        fields: kept,
        metadata: source.metadata.clone(),
    })
}

#[cfg(test)]
mod test {
use std::collections::HashMap;
use std::sync::Mutex;

use crate::dataset::WriteParams;
use arrow_array::{
ArrayRef, Int32Array, ListArray, RecordBatchIterator, StringArray, StructArray,
};

use super::*;
use arrow_array::{Int32Array, RecordBatchIterator};
use arrow_schema::Fields as ArrowFields;
use lance_core::utils::tempfile::TempStrDir;
use lance_file::version::LanceFileVersion;
Expand Down Expand Up @@ -1230,6 +1257,195 @@ mod test {
Ok(())
}

async fn prepare_dataset(version: LanceFileVersion) -> Result<Dataset> {
// id: int32
// people: list<struct<name: utf8, age: int32, city: utf8>>
let person_struct_type = DataType::Struct(ArrowFields::from(vec![
ArrowField::new("name", DataType::Utf8, false),
ArrowField::new("age", DataType::Int32, false),
ArrowField::new("city", DataType::Utf8, false),
]));

let list_of_struct_type = DataType::List(Arc::new(ArrowField::new(
"item",
person_struct_type.clone(),
false,
)));

let schema = Arc::new(ArrowSchema::new_with_metadata(
vec![
ArrowField::new("id", DataType::Int32, false),
ArrowField::new("people", list_of_struct_type.clone(), false),
],
HashMap::<String, String>::new(),
));

// Data: 3 rows, people is a list of 2, 3, 1 structs
let all_names = StringArray::from(vec!["Alice", "Bob", "Charlie", "David", "Eve", "Frank"]);
let all_ages = Int32Array::from(vec![25, 30, 35, 28, 32, 40]);
let all_cities = StringArray::from(vec![
"Beijing",
"Shanghai",
"Guangzhou",
"Shenzhen",
"Hangzhou",
"Chengdu",
]);
let all_struct = StructArray::new(
ArrowFields::from(vec![
ArrowField::new("name", DataType::Utf8, false),
ArrowField::new("age", DataType::Int32, false),
ArrowField::new("city", DataType::Utf8, false),
]),
vec![
Arc::new(all_names) as ArrayRef,
Arc::new(all_ages) as ArrayRef,
Arc::new(all_cities) as ArrayRef,
],
None,
);

let all_people = ListArray::new(
Arc::new(ArrowField::new("item", person_struct_type, false)),
arrow_buffer::OffsetBuffer::new(arrow_buffer::ScalarBuffer::from(vec![
0i32, 2i32, 5i32, 6i32,
])),
Arc::new(all_struct),
None,
);

let ids = Int32Array::from(vec![1, 2, 3]);
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(ids) as ArrayRef, Arc::new(all_people) as ArrayRef],
)?;

let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let dataset = Dataset::write(
reader,
"memory://test",
Some(WriteParams {
data_storage_version: Some(version),
..Default::default()
}),
)
.await?;

// Verify schema
assert_eq!(dataset.schema().fields.len(), 2);
assert_eq!(dataset.schema().fields[0].name, "id");
assert_eq!(dataset.schema().fields[1].name, "people");

Ok(dataset)
}

#[rstest]
#[tokio::test]
async fn test_drop_list_struct_sub_columns_legacy(
    #[values(
        LanceFileVersion::Legacy,
        LanceFileVersion::V2_0,
        LanceFileVersion::V2_1
    )]
    version: LanceFileVersion,
) -> Result<()> {
    let mut dataset = prepare_dataset(version).await?;

    // Dropping a sub-column of a list<struct> is not supported on these
    // versions, so the whole `people` column is expected to be removed.
    dataset.drop_columns(&["people.item.city"]).await?;
    dataset.validate().await?;

    let fields = &dataset.schema().fields;
    assert_eq!(fields.len(), 1);
    assert_eq!(fields[0].name, "id");

    Ok(())
}

#[rstest]
#[tokio::test]
async fn test_drop_list_struct_sub_columns(
    #[values(LanceFileVersion::V2_2)] version: LanceFileVersion,
) -> Result<()> {
    let mut dataset = prepare_dataset(version).await?;

    // 2.2+ supports removing a sub-column from list<struct> in place.
    dataset.drop_columns(&["people.item.city"]).await?;
    dataset.validate().await?;

    // After the drop, `people.item` should only contain `name` and `age`.
    let trimmed_item = ArrowField::new(
        "item",
        DataType::Struct(ArrowFields::from(vec![
            ArrowField::new("name", DataType::Utf8, false),
            ArrowField::new("age", DataType::Int32, false),
        ])),
        false,
    );
    let expected_schema = ArrowSchema::new_with_metadata(
        vec![
            ArrowField::new("id", DataType::Int32, false),
            ArrowField::new("people", DataType::List(Arc::new(trimmed_item)), false),
        ],
        HashMap::<String, String>::new(),
    );
    assert_eq!(ArrowSchema::from(dataset.schema()), expected_schema);

    // The data itself must no longer carry a `city` child array.
    let batch = dataset.scan().try_into_batch().await?;
    assert_eq!((batch.num_rows(), batch.num_columns()), (3, 2));

    let people = batch
        .column(1)
        .as_any()
        .downcast_ref::<ListArray>()
        .unwrap();
    let first_row = people.value(0);
    let people_structs = first_row.as_any().downcast_ref::<StructArray>().unwrap();
    assert!(people_structs.column_by_name("city").is_none());

    Ok(())
}

#[test]
fn test_exclude_fields() {
    // Source schema: a, b<f1, f2, f3>, c.
    let arrow_schema = ArrowSchema::new(vec![
        ArrowField::new("a", DataType::Int32, false),
        ArrowField::new(
            "b",
            DataType::Struct(ArrowFields::from(vec![
                ArrowField::new("f1", DataType::Utf8, true),
                ArrowField::new("f2", DataType::Boolean, false),
                ArrowField::new("f3", DataType::Float32, false),
            ])),
            true,
        ),
        ArrowField::new("c", DataType::Float64, false),
    ]);
    let schema = Schema::try_from(&arrow_schema).unwrap();

    // Exclude `a` entirely plus two of b's three sub-fields; expect b<f1>, c.
    let to_remove = schema.project(&["a", "b.f2", "b.f3"]).unwrap();
    let remaining = exclude(&schema, &to_remove, &LanceFileVersion::V2_2).unwrap();

    let expected = ArrowSchema::new(vec![
        ArrowField::new(
            "b",
            DataType::Struct(ArrowFields::from(vec![ArrowField::new(
                "f1",
                DataType::Utf8,
                true,
            )])),
            true,
        ),
        ArrowField::new("c", DataType::Float64, false),
    ]);
    assert_eq!(ArrowSchema::from(&remaining), expected);
}

#[rstest]
#[tokio::test]
async fn test_rename_columns(
Expand Down