Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7,359 changes: 1 addition & 7,358 deletions rust/lance/src/dataset.rs

Large diffs are not rendered by default.

118 changes: 118 additions & 0 deletions rust/lance/src/dataset/tests/dataset_common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::sync::Arc;

use arrow::array::as_struct_array;
use arrow::compute::concat_batches;
use arrow_array::{
ArrayRef, DictionaryArray, Int32Array, RecordBatch, RecordBatchIterator, StringArray,
StructArray, UInt16Array,
};
use arrow_ord::sort::sort_to_indices;
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use arrow_select::take::take;
use futures::TryStreamExt;
use lance_file::version::LanceFileVersion;
use lance_table::format::WriterVersion;

use crate::dataset::write::WriteParams;
use crate::dataset::WriteMode;
use crate::Dataset;

// Used to validate that futures returned are Send.
pub(super) fn require_send<T: Send>(t: T) -> T {
t
}

pub(super) async fn create_file(
path: &std::path::Path,
mode: WriteMode,
data_storage_version: LanceFileVersion,
) {
let fields = vec![
ArrowField::new("i", DataType::Int32, false),
ArrowField::new(
"dict",
DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
false,
),
];
let schema = Arc::new(ArrowSchema::new(fields));
let dict_values = StringArray::from_iter_values(["a", "b", "c", "d", "e"]);
let batches: Vec<RecordBatch> = (0..20)
.map(|i| {
let mut arrays =
vec![Arc::new(Int32Array::from_iter_values(i * 20..(i + 1) * 20)) as ArrayRef];
arrays.push(Arc::new(
DictionaryArray::try_new(
UInt16Array::from_iter_values((0_u16..20_u16).map(|v| v % 5)),
Arc::new(dict_values.clone()),
)
.unwrap(),
));
RecordBatch::try_new(schema.clone(), arrays).unwrap()
})
.collect();
let expected_batches = batches.clone();

let test_uri = path.to_str().unwrap();
let write_params = WriteParams {
max_rows_per_file: 40,
max_rows_per_group: 10,
mode,
data_storage_version: Some(data_storage_version),
..WriteParams::default()
};
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
Dataset::write(reader, test_uri, Some(write_params))
.await
.unwrap();

let actual_ds = Dataset::open(test_uri).await.unwrap();
assert_eq!(actual_ds.version().version, 1);
assert_eq!(
actual_ds.manifest.writer_version,
Some(WriterVersion::default())
);
let actual_schema = ArrowSchema::from(actual_ds.schema());
assert_eq!(&actual_schema, schema.as_ref());

let actual_batches = actual_ds
.scan()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();

// The batch size batches the group size.
// (the v2 writer has no concept of group size)
if data_storage_version == LanceFileVersion::Legacy {
for batch in &actual_batches {
assert_eq!(batch.num_rows(), 10);
}
}

// sort
let actual_batch = concat_batches(&schema, &actual_batches).unwrap();
let idx_arr = actual_batch.column_by_name("i").unwrap();
let sorted_indices = sort_to_indices(idx_arr, None, None).unwrap();
let struct_arr: StructArray = actual_batch.into();
let sorted_arr = take(&struct_arr, &sorted_indices, None).unwrap();

let expected_struct_arr: StructArray =
concat_batches(&schema, &expected_batches).unwrap().into();
assert_eq!(&expected_struct_arr, as_struct_array(sorted_arr.as_ref()));

// Each fragment has different fragment ID
assert_eq!(
actual_ds
.fragments()
.iter()
.map(|f| f.id)
.collect::<Vec<_>>(),
(0..10).collect::<Vec<_>>()
);
}
Loading
Loading