Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1784,40 +1784,57 @@ def test_load_scanner_from_fragments(tmp_path: Path):
assert scanner.to_table().num_rows == 2 * 100


def test_merge_data_legacy(tmp_path: Path):
    """Merge against a dataset written with the legacy storage format.

    The legacy file format cannot store nulls for Int64, so a partial
    merge (which would have to fill unmatched rows with nulls) must be
    rejected with an explanatory error.
    """
    tab = pa.table({"a": range(100), "b": range(100)})
    lance.write_dataset(
        tab, tmp_path / "dataset", mode="append", data_storage_version="legacy"
    )

    dataset = lance.dataset(tmp_path / "dataset")

    # rejects partial data for non-nullable types
    new_tab = pa.table({"a": range(40), "c": range(40)})
    # TODO: this should be ValueError
    with pytest.raises(OSError, match=r"Join produced null values for type: Int64"):
        dataset.merge(new_tab, "a")


def test_merge_data(tmp_path: Path):
tab = pa.table({"a": range(100)})
lance.write_dataset(tab, tmp_path / "dataset", mode="append")

dataset = lance.dataset(tmp_path / "dataset")

# accepts partial data for nullable types
new_tab = pa.table({"a": range(40), "b": range(40)})
dataset.merge(new_tab, "a")
assert dataset.version == 2
assert dataset.to_table() == pa.table(
{
"a": range(100),
"b": pa.array(list(range(40)) + [None] * 60),
}
)

# accepts a full merge
new_tab = pa.table({"a": range(100), "c": range(100)})
dataset.merge(new_tab, "a")
assert dataset.version == 2
assert dataset.version == 3
assert dataset.to_table() == pa.table(
{
"a": range(100),
"b": range(100),
"b": pa.array(list(range(40)) + [None] * 60),
"c": range(100),
}
)

# accepts a partial for string
new_tab = pa.table({"a2": range(5), "d": ["a", "b", "c", "d", "e"]})
dataset.merge(new_tab, left_on="a", right_on="a2")
assert dataset.version == 3
assert dataset.version == 4
expected = pa.table(
{
"a": range(100),
"b": range(100),
"b": pa.array(list(range(40)) + [None] * 60),
"c": range(100),
"d": ["a", "b", "c", "d", "e"] + [None] * 95,
}
Expand Down
13 changes: 0 additions & 13 deletions rust/lance-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,19 +427,6 @@ impl PartialEq for Dictionary {
}
}

/// Returns true if Lance supports writing this datatype with nulls.
///
/// NOTE(review): this is a global, format-version-agnostic answer. The
/// surrounding change replaces it with a dataset-aware check
/// (`Dataset::lance_supports_nulls`) that consults the dataset's
/// configured data storage format — confirm no remaining callers rely
/// on this free function.
pub fn lance_supports_nulls(datatype: &DataType) -> bool {
    // Only variable-length and fixed-size binary/list types are listed;
    // primitive types such as Int64 are treated as non-nullable here.
    matches!(
        datatype,
        DataType::Utf8
            | DataType::LargeUtf8
            | DataType::Binary
            | DataType::List(_)
            | DataType::FixedSizeBinary(_)
            | DataType::FixedSizeList(_, _)
    )
}

/// Physical storage mode for blob v2 descriptors (one byte, stored in the packed struct column).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
Expand Down
4 changes: 3 additions & 1 deletion rust/lance/src/dataset/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1590,7 +1590,9 @@ impl FileFragment {
let mut updater = self.updater(Some(&[join_column]), None, None).await?;

while let Some(batch) = updater.next().await? {
let batch = joiner.collect(batch[join_column].clone()).await?;
let batch = joiner
.collect(&self.dataset, batch[join_column].clone())
.await?;
updater.update(batch).await?;
}

Expand Down
93 changes: 50 additions & 43 deletions rust/lance/src/dataset/hash_joiner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

use std::sync::Arc;

use crate::{Dataset, Error, Result};
use arrow_array::ArrayRef;
use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchReader};
use arrow_row::{OwnedRow, RowConverter, Rows, SortField};
Expand All @@ -16,9 +17,6 @@ use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use snafu::location;
use tokio::task;

use crate::datatypes::lance_supports_nulls;
use crate::{Dataset, Error, Result};

/// `HashJoiner` does hash join on two datasets.
pub struct HashJoiner {
index_map: ReadOnlyView<OwnedRow, (usize, usize)>,
Expand Down Expand Up @@ -133,7 +131,11 @@ impl HashJoiner {
/// Collecting the data using the index column from left table.
///
/// Will run in parallel over columns using all available cores.
pub(super) async fn collect(&self, index_column: ArrayRef) -> Result<RecordBatch> {
pub(super) async fn collect(
&self,
dataset: &Dataset,
index_column: ArrayRef,
) -> Result<RecordBatch> {
if index_column.data_type() != &self.index_type {
return Err(Error::invalid_input(
format!(
Expand Down Expand Up @@ -180,29 +182,18 @@ impl HashJoiner {
async move {
let task_result = task::spawn_blocking(move || {
let array_refs = arrays.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
interleave(array_refs.as_ref(), indices.as_ref())
.map_err(|err| Error::invalid_input(
format!("HashJoiner: {}", err),
location!(),
))
interleave(array_refs.as_ref(), indices.as_ref()).map_err(|err| {
Error::invalid_input(format!("HashJoiner: {}", err), location!())
})
})
.await;
match task_result {
Ok(Ok(array)) => {
if array.null_count() > 0 && !lance_supports_nulls(array.data_type()) {
return Err(Error::invalid_input(format!(
"Found rows on LHS that do not match any rows on RHS. Lance would need to write \
nulls on the RHS, but Lance does not yet support nulls for type {:?}.",
array.data_type()
), location!()));
}
Self::check_lance_support_null(&array, dataset)?;
Ok(array)
},
}
Ok(Err(err)) => Err(err),
Err(err) => Err(Error::io(
format!("HashJoiner: {}", err),
location!(),
)),
Err(err) => Err(Error::io(format!("HashJoiner: {}", err), location!())),
}
}
})
Expand All @@ -213,6 +204,27 @@ impl HashJoiner {
Ok(RecordBatch::try_new(self.batches[0].schema(), columns)?)
}

pub fn check_lance_support_null(array: &ArrayRef, dataset: &Dataset) -> Result<()> {
if array.null_count() > 0 && !dataset.lance_supports_nulls(array.data_type()) {
return Err(Error::invalid_input(
format!(
"Join produced null values for type: {:?}, but storing \
nulls for this data type is not supported by the \
dataset's current Lance file format version: {:?}. This \
can be caused by an explicit null in the new data.",
array.data_type(),
dataset
.manifest()
.data_storage_format
.lance_file_version()
.unwrap()
),
location!(),
));
}
Ok(())
}

/// Collecting the data using the index column from left table,
/// invalid join column values in left table will be filled with origin values in left table
///
Expand Down Expand Up @@ -271,25 +283,7 @@ impl HashJoiner {
.await;
match task_result {
Ok(Ok(array)) => {
if array.null_count() > 0
&& !dataset.lance_supports_nulls(array.data_type())
{
return Err(Error::invalid_input(
format!(
"Join produced null values for type: {:?}, but storing \
nulls for this data type is not supported by the \
dataset's current Lance file format version: {:?}. This \
can be caused by an explicit null in the new data.",
array.data_type(),
dataset
.manifest()
.data_storage_format
.lance_file_version()
.unwrap()
),
location!(),
));
}
Self::check_lance_support_null(&array, dataset)?;
Ok(array)
}
Ok(Err(err)) => Err(err),
Expand All @@ -311,9 +305,18 @@ impl HashJoiner {
mod tests {

use super::*;

use arrow_array::{Int32Array, RecordBatchIterator, StringArray, UInt32Array};
use arrow_schema::{DataType, Field, Schema};
use lance_core::utils::tempfile::TempDir;

/// Write an empty dataset with a single nullable Int32 column `i` into a
/// fresh temporary directory, then reopen and return it for use by the
/// joiner tests.
async fn create_dataset() -> Dataset {
    let path = TempDir::default().path_str();
    let int_field = Field::new("i", DataType::Int32, true);
    let schema = Arc::new(Schema::new(vec![int_field]));
    // No batches are needed — the tests only require a live dataset handle.
    let empty = RecordBatchIterator::new(std::iter::empty().map(Ok), schema);
    Dataset::write(empty, &path, None).await.unwrap();
    Dataset::open(&path).await.unwrap()
}

#[tokio::test]
async fn test_joiner_collect() {
Expand Down Expand Up @@ -343,6 +346,8 @@ mod tests {
));
let joiner = HashJoiner::try_new(batches, "i").await.unwrap();

let dataset = create_dataset().await;

let indices = Arc::new(Int32Array::from_iter(&[
Some(15),
None,
Expand All @@ -353,7 +358,7 @@ mod tests {
Some(22),
Some(11111), // not found
]));
let results = joiner.collect(indices).await.unwrap();
let results = joiner.collect(&dataset, indices).await.unwrap();

assert_eq!(
results.column_by_name("s").unwrap().as_ref(),
Expand Down Expand Up @@ -394,9 +399,11 @@ mod tests {

let joiner = HashJoiner::try_new(batches, "i").await.unwrap();

let dataset = create_dataset().await;

// Wrong type: was Int32, passing UInt32.
let indices = Arc::new(UInt32Array::from_iter(&[Some(15)]));
let result = joiner.collect(indices).await;
let result = joiner.collect(&dataset, indices).await;
assert!(result.is_err());
assert!(result
.unwrap_err()
Expand Down
Loading