Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 248 additions & 12 deletions rust/lance/src/dataset/schema_evolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use super::{
use crate::{io::exec::Planner, Error, Result};
use arrow::compute::can_cast_types;
use arrow::compute::CastOptions;
use arrow_array::{RecordBatch, RecordBatchReader};
use arrow_array::{Array, RecordBatch, RecordBatchReader};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use datafusion::execution::SendableRecordBatchStream;
use futures::stream::{StreamExt, TryStreamExt};
Expand All @@ -29,6 +29,49 @@ use optimize::{
ChainedNewColumnTransformOptimizer, NewColumnTransformOptimizer, SqlToAllNullsOptimizer,
};

/// Checks that the column at `path` holds no NULL values, so that it can be
/// safely altered from nullable to non-nullable.
///
/// Returns `Ok(())` immediately when the field is already declared
/// non-nullable. Otherwise the dataset is scanned with only `path` projected
/// and every resulting batch is inspected.
///
/// # Errors
///
/// * An invalid-input error if `path` is not a field of the dataset schema.
/// * An invalid-input error if any scanned batch contains a NULL in the column.
/// * [`Error::Internal`] if a scanned batch does not have exactly one column.
/// * Any error raised while configuring or driving the scan is propagated.
async fn validate_no_nulls_before_making_non_nullable(dataset: &Dataset, path: &str) -> Result<()> {
    let field = dataset.schema().field(path).ok_or_else(|| {
        Error::invalid_input(
            format!("Column \"{}\" does not exist in the dataset", path),
            location!(),
        )
    })?;

    // Nothing to verify: the schema already declares the column non-nullable.
    if !field.nullable {
        return Ok(());
    }

    let mut scanner = dataset.scan();
    scanner.project(&[path])?;
    let mut stream = scanner.try_into_stream().await?;
    while let Some(batch) = stream.try_next().await? {
        // `path` can be a nested path (e.g. "b.c") which will not be found by
        // `RecordBatch::column_by_name`. We project exactly one column and validate it directly.
        if batch.num_columns() != 1 {
            return Err(Error::Internal {
                message: format!(
                    "Expected exactly one column in validation scan for {}, got {}",
                    path,
                    batch.num_columns()
                ),
                location: location!(),
            });
        }

        // Use the logical null mask rather than `null_count()`: the latter only
        // reports physical nulls stored in the array's own validity buffer and
        // therefore misses encodings whose nulls live elsewhere (for example
        // `Null`-typed arrays or dictionary arrays with null dictionary values).
        let has_nulls = batch
            .column(0)
            .logical_nulls()
            .is_some_and(|nulls| nulls.null_count() > 0);
        if has_nulls {
            return Err(Error::invalid_input(
                format!(
                    "Column \"{}\" contains NULL values and cannot be made non-nullable",
                    path
                ),
                location!(),
            ));
        }
    }

    Ok(())
}

#[derive(Debug, Clone, PartialEq)]
pub struct BatchInfo {
pub fragment_id: u32,
Expand Down Expand Up @@ -522,8 +565,8 @@ pub(super) async fn alter_columns(
dataset: &mut Dataset,
alterations: &[ColumnAlteration],
) -> Result<()> {
// Validate we aren't making nullable columns non-nullable and that all
// the referenced columns actually exist.
// Validate referenced columns exist and enforce NOT NULL when tightening
// a column from nullable to non-nullable.
let mut new_schema = dataset.schema().clone();

// Mapping of old to new fields that need to be casted.
Expand All @@ -543,16 +586,8 @@ pub(super) async fn alter_columns(
})?;

if let Some(nullable) = alteration.nullable {
// TODO: in the future, we could check the values of the column to see if
// they are all non-null and thus the column could be made non-nullable.
if field_src.nullable && !nullable {
return Err(Error::invalid_input(
format!(
"Column \"{}\" is already nullable and thus cannot be made non-nullable",
alteration.path
),
location!(),
));
validate_no_nulls_before_making_non_nullable(dataset, &alteration.path).await?;
}
}

Expand Down Expand Up @@ -1563,6 +1598,207 @@ mod test {
Ok(())
}

/// Tightening a nullable top-level column to non-nullable succeeds when the
/// stored data contains no NULL values.
#[rstest]
#[tokio::test]
async fn test_set_not_null_succeeds(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
    data_storage_version: LanceFileVersion,
) -> Result<()> {
    // Column "a" is declared nullable, but every written value is present.
    let nullable_a = ArrowField::new("a", DataType::Int32, true);
    let schema = Arc::new(ArrowSchema::new(vec![nullable_a]));
    let values = Int32Array::from_iter_values([1, 2, 3]);
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(values)])?;

    let tmp_dir = TempStrDir::default();
    let uri = &tmp_dir;
    let params = WriteParams {
        data_storage_version: Some(data_storage_version),
        ..Default::default()
    };
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
    let mut dataset = Dataset::write(reader, uri, Some(params)).await?;

    let fragments_before = dataset.fragments().to_vec();
    let tighten = ColumnAlteration::new("a".into()).set_nullable(false);
    dataset.alter_columns(&[tighten]).await?;
    dataset.validate().await?;

    // The alteration produced version 2 and left the fragments as they were.
    assert_eq!(dataset.manifest.version, 2);
    assert_eq!(dataset.fragments().as_ref(), &fragments_before);

    let expected_schema = ArrowSchema::new(vec![ArrowField::new("a", DataType::Int32, false)]);
    assert_eq!(&ArrowSchema::from(dataset.schema()), &expected_schema);

    Ok(())
}

/// Tightening a nullable nested field ("b.c") to non-nullable succeeds when
/// the stored data contains no NULL values.
#[rstest]
#[tokio::test]
async fn test_set_not_null_succeeds_nested(
    #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
    data_storage_version: LanceFileVersion,
) -> Result<()> {
    use arrow_array::{ArrayRef, StructArray};

    // Struct column "b" (non-nullable) wrapping a nullable Int32 child "c".
    let child = ArrowField::new("c", DataType::Int32, true);
    let parent = ArrowField::new(
        "b",
        DataType::Struct(ArrowFields::from(vec![child.clone()])),
        false,
    );
    let schema = Arc::new(ArrowSchema::new(vec![parent]));

    let child_values = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
    let struct_column = StructArray::from(vec![(Arc::new(child), child_values)]);
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(struct_column)])?;

    let tmp_dir = TempStrDir::default();
    let uri = &tmp_dir;
    let params = WriteParams {
        data_storage_version: Some(data_storage_version),
        ..Default::default()
    };
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
    let mut dataset = Dataset::write(reader, uri, Some(params)).await?;

    let fragments_before = dataset.fragments().to_vec();
    let tighten = ColumnAlteration::new("b.c".into()).set_nullable(false);
    dataset.alter_columns(&[tighten]).await?;
    dataset.validate().await?;

    // Fragments are untouched; only the child's nullability flag flipped.
    assert_eq!(dataset.fragments().as_ref(), &fragments_before);

    let expected_child = ArrowField::new("c", DataType::Int32, false);
    let expected_schema = ArrowSchema::new(vec![ArrowField::new(
        "b",
        DataType::Struct(ArrowFields::from(vec![expected_child])),
        false,
    )]);
    assert_eq!(&ArrowSchema::from(dataset.schema()), &expected_schema);

    Ok(())
}

/// Tightening a nullable top-level column to non-nullable is rejected when
/// the stored data contains a NULL, and the schema is left as it was.
#[rstest]
#[tokio::test]
async fn test_set_not_null_fails_with_nulls(
    #[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion,
) -> Result<()> {
    // Column "a" is nullable and its middle value is actually NULL.
    let nullable_a = ArrowField::new("a", DataType::Int32, true);
    let schema = Arc::new(ArrowSchema::new(vec![nullable_a]));
    let values = Int32Array::from(vec![Some(1), None, Some(3)]);
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(values)])?;

    let tmp_dir = TempStrDir::default();
    let uri = &tmp_dir;
    let params = WriteParams {
        data_storage_version: Some(data_storage_version),
        ..Default::default()
    };
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
    let mut dataset = Dataset::write(reader, uri, Some(params)).await?;

    let tighten = ColumnAlteration::new("a".into()).set_nullable(false);
    let err = dataset.alter_columns(&[tighten]).await.unwrap_err();
    let message = err.to_string();
    assert!(message.contains("contains NULL values"));

    // The rejected alteration must not have changed the column's nullability.
    let expected_schema = ArrowSchema::new(vec![ArrowField::new("a", DataType::Int32, true)]);
    assert_eq!(&ArrowSchema::from(dataset.schema()), &expected_schema);

    Ok(())
}

/// Tightening a nullable nested field ("b.c") to non-nullable is rejected
/// when the stored data contains a NULL, and the schema is left as it was.
#[rstest]
#[tokio::test]
async fn test_set_not_null_fails_with_nulls_nested(
    #[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion,
) -> Result<()> {
    use arrow_array::{ArrayRef, StructArray};

    // Struct column "b" (non-nullable) wrapping a nullable Int32 child "c".
    let child = ArrowField::new("c", DataType::Int32, true);
    let parent = ArrowField::new(
        "b",
        DataType::Struct(ArrowFields::from(vec![child.clone()])),
        false,
    );
    let schema = Arc::new(ArrowSchema::new(vec![parent]));

    // The child's middle value is NULL.
    let child_values = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as ArrayRef;
    let struct_column = StructArray::from(vec![(Arc::new(child), child_values)]);
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(struct_column)])?;

    let tmp_dir = TempStrDir::default();
    let uri = &tmp_dir;
    let params = WriteParams {
        data_storage_version: Some(data_storage_version),
        ..Default::default()
    };
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
    let mut dataset = Dataset::write(reader, uri, Some(params)).await?;

    let tighten = ColumnAlteration::new("b.c".into()).set_nullable(false);
    let err = dataset.alter_columns(&[tighten]).await.unwrap_err();
    let message = err.to_string();
    assert!(message.contains("contains NULL values"));

    // The rejected alteration must not have changed the child's nullability.
    let expected_child = ArrowField::new("c", DataType::Int32, true);
    let expected_schema = ArrowSchema::new(vec![ArrowField::new(
        "b",
        DataType::Struct(ArrowFields::from(vec![expected_child])),
        false,
    )]);
    assert_eq!(&ArrowSchema::from(dataset.schema()), &expected_schema);

    Ok(())
}

#[rstest]
#[tokio::test]
async fn test_cast_column(
Expand Down
Loading