diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 9cf7c824a60..1d7e7e2ef26 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -1784,28 +1784,45 @@ def test_load_scanner_from_fragments(tmp_path: Path): assert scanner.to_table().num_rows == 2 * 100 -def test_merge_data(tmp_path: Path): +def test_merge_data_legacy(tmp_path: Path): tab = pa.table({"a": range(100), "b": range(100)}) - lance.write_dataset(tab, tmp_path / "dataset", mode="append") + lance.write_dataset( + tab, tmp_path / "dataset", mode="append", data_storage_version="legacy" + ) dataset = lance.dataset(tmp_path / "dataset") # rejects partial data for non-nullable types new_tab = pa.table({"a": range(40), "c": range(40)}) - # TODO: this should be ValueError - with pytest.raises( - OSError, match=".+Lance does not yet support nulls for type Int64." - ): + with pytest.raises(OSError, match=r"Join produced null values for type: Int64"): dataset.merge(new_tab, "a") + +def test_merge_data(tmp_path: Path): + tab = pa.table({"a": range(100)}) + lance.write_dataset(tab, tmp_path / "dataset", mode="append") + + dataset = lance.dataset(tmp_path / "dataset") + + # accepts partial data for nullable types + new_tab = pa.table({"a": range(40), "b": range(40)}) + dataset.merge(new_tab, "a") + assert dataset.version == 2 + assert dataset.to_table() == pa.table( + { + "a": range(100), + "b": pa.array(list(range(40)) + [None] * 60), + } + ) + # accepts a full merge new_tab = pa.table({"a": range(100), "c": range(100)}) dataset.merge(new_tab, "a") - assert dataset.version == 2 + assert dataset.version == 3 assert dataset.to_table() == pa.table( { "a": range(100), - "b": range(100), + "b": pa.array(list(range(40)) + [None] * 60), "c": range(100), } ) @@ -1813,11 +1830,11 @@ def test_merge_data(tmp_path: Path): # accepts a partial for string new_tab = pa.table({"a2": range(5), "d": ["a", "b", "c", "d", "e"]}) dataset.merge(new_tab, left_on="a", right_on="a2") - assert dataset.version == 3 + assert dataset.version == 4 expected = pa.table( { "a": range(100), - "b": range(100), + "b": pa.array(list(range(40)) + [None] * 60), "c": range(100), "d": ["a", "b", "c", "d", "e"] + [None] * 95, } diff --git a/rust/lance-core/src/datatypes.rs b/rust/lance-core/src/datatypes.rs index dd56e610f52..704c1c4dbe6 100644 --- a/rust/lance-core/src/datatypes.rs +++ b/rust/lance-core/src/datatypes.rs @@ -427,19 +427,6 @@ impl PartialEq for Dictionary { } } -/// Returns true if Lance supports writing this datatype with nulls. -pub fn lance_supports_nulls(datatype: &DataType) -> bool { - matches!( - datatype, - DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Binary - | DataType::List(_) - | DataType::FixedSizeBinary(_) - | DataType::FixedSizeList(_, _) - ) -} - /// Physical storage mode for blob v2 descriptors (one byte, stored in the packed struct column). #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs index 8428cf619b4..d359e84906b 100644 --- a/rust/lance/src/dataset/fragment.rs +++ b/rust/lance/src/dataset/fragment.rs @@ -1590,7 +1590,9 @@ impl FileFragment { let mut updater = self.updater(Some(&[join_column]), None, None).await?; while let Some(batch) = updater.next().await? { - let batch = joiner.collect(batch[join_column].clone()).await?; + let batch = joiner + .collect(&self.dataset, batch[join_column].clone()) + .await?; updater.update(batch).await?; } diff --git a/rust/lance/src/dataset/hash_joiner.rs b/rust/lance/src/dataset/hash_joiner.rs index e9f8c14d9bb..7952c41d78f 100644 --- a/rust/lance/src/dataset/hash_joiner.rs +++ b/rust/lance/src/dataset/hash_joiner.rs @@ -5,6 +5,7 @@ use std::sync::Arc; +use crate::{Dataset, Error, Result}; use arrow_array::ArrayRef; use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchReader}; use arrow_row::{OwnedRow, RowConverter, Rows, SortField}; @@ -16,9 +17,6 @@ use lance_core::utils::tokio::get_num_compute_intensive_cpus; use snafu::location; use tokio::task; -use crate::datatypes::lance_supports_nulls; -use crate::{Dataset, Error, Result}; - /// `HashJoiner` does hash join on two datasets. pub struct HashJoiner { index_map: ReadOnlyView, @@ -133,7 +131,11 @@ impl HashJoiner { /// Collecting the data using the index column from left table. /// /// Will run in parallel over columns using all available cores. - pub(super) async fn collect(&self, index_column: ArrayRef) -> Result { + pub(super) async fn collect( + &self, + dataset: &Dataset, + index_column: ArrayRef, + ) -> Result { if index_column.data_type() != &self.index_type { return Err(Error::invalid_input( format!( @@ -180,29 +182,18 @@ impl HashJoiner { async move { let task_result = task::spawn_blocking(move || { let array_refs = arrays.iter().map(|x| x.as_ref()).collect::>(); - interleave(array_refs.as_ref(), indices.as_ref()) - .map_err(|err| Error::invalid_input( - format!("HashJoiner: {}", err), - location!(), - )) + interleave(array_refs.as_ref(), indices.as_ref()).map_err(|err| { + Error::invalid_input(format!("HashJoiner: {}", err), location!()) + }) }) .await; match task_result { Ok(Ok(array)) => { - if array.null_count() > 0 && !lance_supports_nulls(array.data_type()) { - return Err(Error::invalid_input(format!( - "Found rows on LHS that do not match any rows on RHS. Lance would need to write \ - nulls on the RHS, but Lance does not yet support nulls for type {:?}.", - array.data_type() - ), location!())); - } + Self::check_lance_support_null(&array, dataset)?; Ok(array) - }, + } Ok(Err(err)) => Err(err), - Err(err) => Err(Error::io( - format!("HashJoiner: {}", err), - location!(), - )), + Err(err) => Err(Error::io(format!("HashJoiner: {}", err), location!())), } } }) @@ -213,6 +204,27 @@ impl HashJoiner { Ok(RecordBatch::try_new(self.batches[0].schema(), columns)?) } + pub fn check_lance_support_null(array: &ArrayRef, dataset: &Dataset) -> Result<()> { + if array.null_count() > 0 && !dataset.lance_supports_nulls(array.data_type()) { + return Err(Error::invalid_input( + format!( + "Join produced null values for type: {:?}, but storing \ + nulls for this data type is not supported by the \ + dataset's current Lance file format version: {:?}. This \ + can be caused by an explicit null in the new data.", + array.data_type(), + dataset + .manifest() + .data_storage_format + .lance_file_version() + .unwrap() + ), + location!(), + )); + } + Ok(()) + } + /// Collecting the data using the index column from left table, /// invalid join column values in left table will be filled with origin values in left table /// @@ -271,25 +283,7 @@ impl HashJoiner { .await; match task_result { Ok(Ok(array)) => { - if array.null_count() > 0 - && !dataset.lance_supports_nulls(array.data_type()) - { - return Err(Error::invalid_input( - format!( - "Join produced null values for type: {:?}, but storing \ - nulls for this data type is not supported by the \ - dataset's current Lance file format version: {:?}. This \ - can be caused by an explicit null in the new data.", - array.data_type(), - dataset - .manifest() - .data_storage_format - .lance_file_version() - .unwrap() - ), - location!(), - )); - } + Self::check_lance_support_null(&array, dataset)?; Ok(array) } Ok(Err(err)) => Err(err), @@ -311,9 +305,18 @@ impl HashJoiner { mod tests { use super::*; - use arrow_array::{Int32Array, RecordBatchIterator, StringArray, UInt32Array}; use arrow_schema::{DataType, Field, Schema}; + use lance_core::utils::tempfile::TempDir; + + async fn create_dataset() -> Dataset { + let uri = TempDir::default().path_str(); + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let batches = RecordBatchIterator::new(std::iter::empty().map(Ok), schema.clone()); + Dataset::write(batches, &uri, None).await.unwrap(); + + Dataset::open(&uri).await.unwrap() + } #[tokio::test] async fn test_joiner_collect() { @@ -343,6 +346,8 @@ mod tests { )); let joiner = HashJoiner::try_new(batches, "i").await.unwrap(); + let dataset = create_dataset().await; + let indices = Arc::new(Int32Array::from_iter(&[ Some(15), None, @@ -353,7 +358,7 @@ mod tests { Some(22), Some(11111), // not found ])); - let results = joiner.collect(indices).await.unwrap(); + let results = joiner.collect(&dataset, indices).await.unwrap(); assert_eq!( results.column_by_name("s").unwrap().as_ref(), @@ -394,9 +399,11 @@ mod tests { let joiner = HashJoiner::try_new(batches, "i").await.unwrap(); + let dataset = create_dataset().await; + // Wrong type: was Int32, passing UInt32. let indices = Arc::new(UInt32Array::from_iter(&[Some(15)])); - let result = joiner.collect(indices).await; + let result = joiner.collect(&dataset, indices).await; assert!(result.is_err()); assert!(result .unwrap_err()