diff --git a/rust/lance/src/dataset/schema_evolution.rs b/rust/lance/src/dataset/schema_evolution.rs
index da5ea5b1688..86752a28b94 100644
--- a/rust/lance/src/dataset/schema_evolution.rs
+++ b/rust/lance/src/dataset/schema_evolution.rs
@@ -11,7 +11,7 @@ use super::{
 use crate::{io::exec::Planner, Error, Result};
 use arrow::compute::can_cast_types;
 use arrow::compute::CastOptions;
-use arrow_array::{RecordBatch, RecordBatchReader};
+use arrow_array::{Array, RecordBatch, RecordBatchReader};
 use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
 use datafusion::execution::SendableRecordBatchStream;
 use futures::stream::{StreamExt, TryStreamExt};
@@ -29,6 +29,65 @@ use optimize::{
     ChainedNewColumnTransformOptimizer, NewColumnTransformOptimizer, SqlToAllNullsOptimizer,
 };
 
+/// Verifies that the column at `path` holds no NULL values, so that it can be
+/// tightened from nullable to non-nullable.
+///
+/// Returns `Ok(())` without scanning when the field is already non-nullable.
+/// Otherwise the column is projected on its own and checked batch by batch,
+/// stopping at the first batch that contains a NULL.
+///
+/// # Errors
+///
+/// * Invalid input if `path` does not exist in the dataset schema.
+/// * Invalid input if any scanned value of the column is NULL.
+/// * Internal error if a scanned batch does not have exactly one column.
+/// * Any error propagated from building or running the scan.
+async fn validate_no_nulls_before_making_non_nullable(dataset: &Dataset, path: &str) -> Result<()> {
+    let field = dataset.schema().field(path).ok_or_else(|| {
+        Error::invalid_input(
+            format!("Column \"{}\" does not exist in the dataset", path),
+            location!(),
+        )
+    })?;
+
+    if !field.nullable {
+        return Ok(());
+    }
+
+    let mut scanner = dataset.scan();
+    scanner.project(&[path])?;
+    let mut stream = scanner.try_into_stream().await?;
+    while let Some(batch) = stream.try_next().await? {
+        // `path` can be a nested path (e.g. "b.c") which will not be found by
+        // `RecordBatch::column_by_name`. We project exactly one column and validate it directly.
+        if batch.num_columns() != 1 {
+            return Err(Error::Internal {
+                message: format!(
+                    "Expected exactly one column in validation scan for {}, got {}",
+                    path,
+                    batch.num_columns()
+                ),
+                location: location!(),
+            });
+        }
+        let col = batch.column(0);
+        // Count logical NULLs: the physical `null_count` only reflects the array's
+        // own validity buffer and misses NULLs encoded elsewhere (e.g. in the
+        // values of a dictionary array).
+        if col.logical_null_count() > 0 {
+            return Err(Error::invalid_input(
+                format!(
+                    "Column \"{}\" contains NULL values and cannot be made non-nullable",
+                    path
+                ),
+                location!(),
+            ));
+        }
+    }
+
+    Ok(())
+}
+
 #[derive(Debug, Clone, PartialEq)]
 pub struct BatchInfo {
     pub fragment_id: u32,
@@ -522,8 +581,8 @@ pub(super) async fn alter_columns(
     dataset: &mut Dataset,
     alterations: &[ColumnAlteration],
 ) -> Result<()> {
-    // Validate we aren't making nullable columns non-nullable and that all
-    // the referenced columns actually exist.
+    // Validate referenced columns exist and enforce NOT NULL when tightening
+    // a column from nullable to non-nullable.
     let mut new_schema = dataset.schema().clone();
 
     // Mapping of old to new fields that need to be casted.
@@ -543,16 +602,8 @@ pub(super) async fn alter_columns(
             })?;
 
         if let Some(nullable) = alteration.nullable {
-            // TODO: in the future, we could check the values of the column to see if
-            // they are all non-null and thus the column could be made non-nullable.
             if field_src.nullable && !nullable {
-                return Err(Error::invalid_input(
-                    format!(
-                        "Column \"{}\" is already nullable and thus cannot be made non-nullable",
-                        alteration.path
-                    ),
-                    location!(),
-                ));
+                validate_no_nulls_before_making_non_nullable(dataset, &alteration.path).await?;
             }
         }
 
@@ -1563,6 +1614,207 @@ mod test {
         Ok(())
     }
 
+    #[rstest]
+    #[tokio::test]
+    async fn test_set_not_null_succeeds(
+        #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
+        data_storage_version: LanceFileVersion,
+    ) -> Result<()> {
+        let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
+            "a",
+            DataType::Int32,
+            true,
+        )]));
+
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(Int32Array::from_iter_values([1, 2, 3]))],
+        )?;
+        let test_dir = TempStrDir::default();
+        let test_uri = &test_dir;
+        let mut dataset = Dataset::write(
+            RecordBatchIterator::new(vec![Ok(batch)], schema.clone()),
+            test_uri,
+            Some(WriteParams {
+                data_storage_version: Some(data_storage_version),
+                ..Default::default()
+            }),
+        )
+        .await?;
+
+        let original_fragments = dataset.fragments().to_vec();
+        dataset
+            .alter_columns(&[ColumnAlteration::new("a".into()).set_nullable(false)])
+            .await?;
+        dataset.validate().await?;
+
+        assert_eq!(dataset.manifest.version, 2);
+        assert_eq!(dataset.fragments().as_ref(), &original_fragments);
+        assert_eq!(
+            &ArrowSchema::from(dataset.schema()),
+            &ArrowSchema::new(vec![ArrowField::new("a", DataType::Int32, false)])
+        );
+
+        Ok(())
+    }
+
+    #[rstest]
+    #[tokio::test]
+    async fn test_set_not_null_succeeds_nested(
+        #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
+        data_storage_version: LanceFileVersion,
+    ) -> Result<()> {
+        use arrow_array::{ArrayRef, StructArray};
+
+        let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
+            "b",
+            DataType::Struct(ArrowFields::from(vec![ArrowField::new(
+                "c",
+                DataType::Int32,
+                true,
+            )])),
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(StructArray::from(vec![(
+                Arc::new(ArrowField::new("c", DataType::Int32, true)),
+                Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
+            )]))],
+        )?;
+        let test_dir = TempStrDir::default();
+        let test_uri = &test_dir;
+        let mut dataset = Dataset::write(
+            RecordBatchIterator::new(vec![Ok(batch)], schema.clone()),
+            test_uri,
+            Some(WriteParams {
+                data_storage_version: Some(data_storage_version),
+                ..Default::default()
+            }),
+        )
+        .await?;
+
+        let original_fragments = dataset.fragments().to_vec();
+        dataset
+            .alter_columns(&[ColumnAlteration::new("b.c".into()).set_nullable(false)])
+            .await?;
+        dataset.validate().await?;
+
+        assert_eq!(dataset.fragments().as_ref(), &original_fragments);
+        assert_eq!(
+            &ArrowSchema::from(dataset.schema()),
+            &ArrowSchema::new(vec![ArrowField::new(
+                "b",
+                DataType::Struct(ArrowFields::from(vec![ArrowField::new(
+                    "c",
+                    DataType::Int32,
+                    false
+                )])),
+                false
+            )])
+        );
+
+        Ok(())
+    }
+
+    #[rstest]
+    #[tokio::test]
+    async fn test_set_not_null_fails_with_nulls(
+        #[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion,
+    ) -> Result<()> {
+        let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
+            "a",
+            DataType::Int32,
+            true,
+        )]));
+
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]))],
+        )?;
+        let test_dir = TempStrDir::default();
+        let test_uri = &test_dir;
+        let mut dataset = Dataset::write(
+            RecordBatchIterator::new(vec![Ok(batch)], schema.clone()),
+            test_uri,
+            Some(WriteParams {
+                data_storage_version: Some(data_storage_version),
+                ..Default::default()
+            }),
+        )
+        .await?;
+
+        let err = dataset
+            .alter_columns(&[ColumnAlteration::new("a".into()).set_nullable(false)])
+            .await
+            .unwrap_err();
+        assert!(err.to_string().contains("contains NULL values"));
+        assert_eq!(
+            &ArrowSchema::from(dataset.schema()),
+            &ArrowSchema::new(vec![ArrowField::new("a", DataType::Int32, true)])
+        );
+
+        Ok(())
+    }
+
+    #[rstest]
+    #[tokio::test]
+    async fn test_set_not_null_fails_with_nulls_nested(
+        #[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion,
+    ) -> Result<()> {
+        use arrow_array::{ArrayRef, StructArray};
+
+        let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
+            "b",
+            DataType::Struct(ArrowFields::from(vec![ArrowField::new(
+                "c",
+                DataType::Int32,
+                true,
+            )])),
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(StructArray::from(vec![(
+                Arc::new(ArrowField::new("c", DataType::Int32, true)),
+                Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as ArrayRef,
+            )]))],
+        )?;
+        let test_dir = TempStrDir::default();
+        let test_uri = &test_dir;
+        let mut dataset = Dataset::write(
+            RecordBatchIterator::new(vec![Ok(batch)], schema.clone()),
+            test_uri,
+            Some(WriteParams {
+                data_storage_version: Some(data_storage_version),
+                ..Default::default()
+            }),
+        )
+        .await?;
+
+        let err = dataset
+            .alter_columns(&[ColumnAlteration::new("b.c".into()).set_nullable(false)])
+            .await
+            .unwrap_err();
+        assert!(err.to_string().contains("contains NULL values"));
+        assert_eq!(
+            &ArrowSchema::from(dataset.schema()),
+            &ArrowSchema::new(vec![ArrowField::new(
+                "b",
+                DataType::Struct(ArrowFields::from(vec![ArrowField::new(
+                    "c",
+                    DataType::Int32,
+                    true
+                )])),
+                false
+            )])
+        );
+
+        Ok(())
+    }
+
     #[rstest]
     #[tokio::test]
     async fn test_cast_column(