Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/lance-datafusion/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ pub fn safe_coerce_scalar(value: &ScalarValue, ty: &DataType) -> Option<ScalarVa
// See above warning about lossy float conversion
DataType::Float32 => val.map(|v| ScalarValue::Float32(Some(v as f32))),
DataType::Float64 => val.map(|v| ScalarValue::Float64(Some(v as f64))),
DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => value.cast_to(ty).ok(),
_ => None,
},
ScalarValue::UInt8(val) => match ty {
Expand Down
5 changes: 3 additions & 2 deletions rust/lance-datagen/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2083,7 +2083,8 @@ pub mod array {
use arrow_array::{
ArrowNativeTypeOp, BooleanArray, Date32Array, Date64Array, Time32MillisecondArray,
Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
TimestampMicrosecondArray, TimestampNanosecondArray, TimestampSecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray,
};
use arrow_schema::{IntervalUnit, TimeUnit};
use chrono::Utc;
Expand Down Expand Up @@ -2518,7 +2519,7 @@ pub mod array {
))
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
Box::new(FnGen::<i64, TimestampMicrosecondArray, _>::new_known_size(
Box::new(FnGen::<i64, TimestampMillisecondArray, _>::new_known_size(
data_type, sample_fn, 1, width,
))
}
Expand Down
339 changes: 319 additions & 20 deletions rust/lance/tests/query/primitives.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::sync::Arc;

use arrow::datatypes::*;
use arrow_array::RecordBatch;
use arrow_array::{
ArrayRef, BinaryArray, BinaryViewArray, Float32Array, Float64Array, Int32Array,
LargeBinaryArray, LargeStringArray, RecordBatch, StringArray, StringViewArray,
};
use arrow_schema::DataType;
use lance::Dataset;

Expand Down Expand Up @@ -51,42 +56,336 @@ async fn test_query_bool() {
#[case::uint32(DataType::UInt32)]
#[case::uint64(DataType::UInt64)]
async fn test_query_integer(#[case] data_type: DataType) {
let value_generator = match data_type {
DataType::Int8 => array::rand_primitive::<Int8Type>(data_type),
DataType::Int16 => array::rand_primitive::<Int16Type>(data_type),
DataType::Int32 => array::rand_primitive::<Int32Type>(data_type),
DataType::Int64 => array::rand_primitive::<Int64Type>(data_type),
DataType::UInt8 => array::rand_primitive::<UInt8Type>(data_type),
DataType::UInt16 => array::rand_primitive::<UInt16Type>(data_type),
DataType::UInt32 => array::rand_primitive::<UInt32Type>(data_type),
DataType::UInt64 => array::rand_primitive::<UInt64Type>(data_type),
let batch = gen_batch()
.col("id", array::step::<Int32Type>())
.col("value", array::rand_type(&data_type).with_random_nulls(0.1))
.into_batch_rows(RowCount::from(60))
.unwrap();
DatasetTestCases::from_data(batch)
.with_index_types(
"value",
[
None,
Some(IndexType::Bitmap),
Some(IndexType::BTree),
Some(IndexType::BloomFilter),
Some(IndexType::ZoneMap),
],
)
.run(|ds: Dataset, original: RecordBatch| async move {
test_scan(&original, &ds).await;
test_take(&original, &ds).await;
test_filter(&original, &ds, "value > 20").await;
test_filter(&original, &ds, "NOT (value > 20)").await;
test_filter(&original, &ds, "value is null").await;
test_filter(&original, &ds, "value is not null").await;
})
.await
}

#[tokio::test]
#[rstest::rstest]
#[case::float32(DataType::Float32)]
#[case::float64(DataType::Float64)]
async fn test_query_float(#[case] data_type: DataType) {
    // Random floats of the parameterized width (~10% nulls) next to a
    // monotonically increasing id column.
    let batch = gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("value", array::rand_type(&data_type).with_random_nulls(0.1))
        .into_batch_rows(RowCount::from(60))
        .unwrap();
    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [
                None,
                Some(IndexType::BTree),
                Some(IndexType::Bitmap),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            // Comparison, negation, null, and NaN predicates, in order.
            for predicate in [
                "value > 0.5",
                "NOT (value > 0.5)",
                "value is null",
                "value is not null",
                "isnan(value)",
                "not isnan(value)",
            ] {
                test_filter(&original, &ds, predicate).await;
            }
        })
        .await
}

// Verifies scans, takes, and filters behave on IEEE-754 edge cases
// (signed zeros, infinities, NaN, min/max finite values) plus a null,
// across all scalar index types and the unindexed path.
#[tokio::test]
#[rstest::rstest]
#[case::float32(DataType::Float32)]
#[case::float64(DataType::Float64)]
async fn test_query_float_special_values(#[case] data_type: DataType) {
    // Hand-built value column so every special value is guaranteed present
    // (random generation would rarely produce them).
    let value_array: Arc<dyn arrow_array::Array> = match data_type {
        DataType::Float32 => Arc::new(Float32Array::from(vec![
            Some(0.0_f32),
            Some(-0.0_f32),
            Some(f32::INFINITY),
            Some(f32::NEG_INFINITY),
            Some(f32::NAN),
            Some(1.0_f32),
            Some(-1.0_f32),
            Some(f32::MIN),
            Some(f32::MAX),
            None,
        ])),
        DataType::Float64 => Arc::new(Float64Array::from(vec![
            Some(0.0_f64),
            Some(-0.0_f64),
            Some(f64::INFINITY),
            Some(f64::NEG_INFINITY),
            Some(f64::NAN),
            Some(1.0_f64),
            Some(-1.0_f64),
            Some(f64::MIN),
            Some(f64::MAX),
            None,
        ])),
        // Only the two #[case] types above can reach here.
        _ => unreachable!(),
    };

    // One id per special value (10 rows).
    let id_array = Arc::new(Int32Array::from((0..10).collect::<Vec<i32>>()));

    let batch =
        RecordBatch::try_from_iter(vec![("id", id_array as ArrayRef), ("value", value_array)])
            .unwrap();

    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [
                None,
                Some(IndexType::BTree),
                Some(IndexType::Bitmap),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            // Comparisons against zero exercise -0.0/+0.0 equality and the
            // ordering of infinities/NaN; then null and NaN predicates.
            test_filter(&original, &ds, "value > 0.0").await;
            test_filter(&original, &ds, "value < 0.0").await;
            test_filter(&original, &ds, "value = 0.0").await;
            test_filter(&original, &ds, "value is null").await;
            test_filter(&original, &ds, "value is not null").await;
            test_filter(&original, &ds, "isnan(value)").await;
            test_filter(&original, &ds, "not isnan(value)").await;
        })
        .await
}

#[tokio::test]
#[rstest::rstest]
#[case::date32(DataType::Date32)]
#[case::date64(DataType::Date64)]
async fn test_query_date(#[case] data_type: DataType) {
    // Random dates of the parameterized type with ~10% nulls.
    //
    // NOTE(review): the pasted diff kept both the deleted
    // `.col("value", value_generator…)` line and the old index list next to
    // their replacements; this is the post-merge body.
    let batch = gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("value", array::rand_type(&data_type).with_random_nulls(0.1))
        .into_batch_rows(RowCount::from(60))
        .unwrap();

    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [
                None,
                Some(IndexType::Bitmap),
                Some(IndexType::BTree),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            test_filter(&original, &ds, "value > 20").await;
            test_filter(&original, &ds, "NOT (value > 20)").await;
            // Dynamic and literal date comparisons.
            test_filter(&original, &ds, "value < current_date()").await;
            test_filter(&original, &ds, "value > DATE '2024-01-01'").await;
            test_filter(&original, &ds, "value is null").await;
            test_filter(&original, &ds, "value is not null").await;
        })
        .await
}

#[tokio::test]
#[rstest::rstest]
#[case::timestamp_second(DataType::Timestamp(TimeUnit::Second, None))]
#[case::timestamp_millisecond(DataType::Timestamp(TimeUnit::Millisecond, None))]
#[case::timestamp_microsecond(DataType::Timestamp(TimeUnit::Microsecond, None))]
#[case::timestamp_nanosecond(DataType::Timestamp(TimeUnit::Nanosecond, None))]
async fn test_query_timestamp(#[case] data_type: DataType) {
    // 60 rows of random timestamps (~10% nulls) in the parameterized unit.
    let batch = gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("value", array::rand_type(&data_type).with_random_nulls(0.1))
        .into_batch_rows(RowCount::from(60))
        .unwrap();

    // Every scalar-index flavor plus the unindexed baseline.
    let index_variants = [
        None,
        Some(IndexType::BTree),
        Some(IndexType::Bitmap),
        Some(IndexType::BloomFilter),
        Some(IndexType::ZoneMap),
    ];

    DatasetTestCases::from_data(batch)
        .with_index_types("value", index_variants)
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            // Dynamic and literal timestamp comparisons, then null checks.
            for predicate in [
                "value < current_timestamp()",
                "value > TIMESTAMP '2024-01-01 00:00:00'",
                "value is null",
                "value is not null",
            ] {
                test_filter(&original, &ds, predicate).await;
            }
        })
        .await
}

// Verifies scans, takes, and filters on string columns, deliberately
// including empty strings and a null, across Utf8 and LargeUtf8 (Utf8View
// pending — see the commented-out case below).
#[tokio::test]
#[rstest::rstest]
#[case::utf8(DataType::Utf8)]
#[case::large_utf8(DataType::LargeUtf8)]
// #[case::string_view(DataType::Utf8View)] // TODO: https://github.com/lancedb/lance/issues/5172
async fn test_query_string(#[case] data_type: DataType) {
    // Create arrays that include empty strings
    let string_values = vec![
        Some("hello"),
        Some("world"),
        Some(""),
        Some("test"),
        Some("data"),
        Some(""),
        None,
        Some("apple"),
        Some("zebra"),
        Some(""),
    ];

    // Materialize the same logical values as the parameterized array type.
    // The Utf8View arm is currently unreachable (its #[case] is disabled
    // above) but kept so enabling the case needs no other change.
    let value_array: ArrayRef = match data_type {
        DataType::Utf8 => Arc::new(StringArray::from(string_values.clone())),
        DataType::LargeUtf8 => Arc::new(LargeStringArray::from(string_values.clone())),
        DataType::Utf8View => Arc::new(StringViewArray::from(string_values.clone())),
        _ => unreachable!(),
    };

    // One id per string value (10 rows).
    let id_array = Arc::new(Int32Array::from((0..10).collect::<Vec<i32>>()));

    let batch =
        RecordBatch::try_from_iter(vec![("id", id_array as ArrayRef), ("value", value_array)])
            .unwrap();

    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [
                None,
                Some(IndexType::Bitmap),
                Some(IndexType::BTree),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            // Equality/inequality, empty-string equality, lexicographic
            // ordering, and null checks.
            test_filter(&original, &ds, "value = 'hello'").await;
            test_filter(&original, &ds, "value != 'hello'").await;
            test_filter(&original, &ds, "value = ''").await;
            test_filter(&original, &ds, "value > 'hello'").await;
            test_filter(&original, &ds, "value is null").await;
            test_filter(&original, &ds, "value is not null").await;
        })
        .await
}

#[tokio::test]
#[rstest::rstest]
#[case::binary(DataType::Binary)]
#[case::large_binary(DataType::LargeBinary)]
// #[case::binary_view(DataType::BinaryView)] // TODO: https://github.com/lancedb/lance/issues/5172
async fn test_query_binary(#[case] data_type: DataType) {
    // Fixed sample values rather than random data so empty binary and null
    // entries are guaranteed to be present.
    let samples: Vec<Option<&[u8]>> = vec![
        Some(b"hello".as_slice()),
        Some(b"world".as_slice()),
        Some(b"".as_slice()),
        Some(b"test".as_slice()),
        Some(b"data".as_slice()),
        Some(b"".as_slice()),
        None,
        Some(b"apple".as_slice()),
        Some(b"zebra".as_slice()),
        Some(b"".as_slice()),
    ];

    // Materialize the samples as the parameterized binary array type. The
    // BinaryView arm is currently unreachable (its #[case] is disabled
    // above) but kept so enabling the case needs no other change.
    let values: ArrayRef = match data_type {
        DataType::Binary => Arc::new(BinaryArray::from(samples.clone())),
        DataType::LargeBinary => Arc::new(LargeBinaryArray::from(samples.clone())),
        DataType::BinaryView => Arc::new(BinaryViewArray::from(samples.clone())),
        _ => unreachable!(),
    };

    // One id per sample (10 rows).
    let ids: Vec<i32> = (0..10).collect();
    let id_array: ArrayRef = Arc::new(Int32Array::from(ids));

    let batch = RecordBatch::try_from_iter(vec![("id", id_array), ("value", values)]).unwrap();

    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            [
                None,
                Some(IndexType::Bitmap),
                Some(IndexType::BTree),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            // X'68656C6C6F' is 'hello' in hex; then null checks.
            for predicate in [
                "value = X'68656C6C6F'",
                "value != X'68656C6C6F'",
                "value is null",
                "value is not null",
            ] {
                test_filter(&original, &ds, predicate).await;
            }
        })
        .await
}

// (Stale TODOs for floats/strings/binary/timestamps removed — those tests
// now exist above in this file.)
#[tokio::test]
#[rstest::rstest]
// TODO: Add Decimal32 and Decimal64 https://github.com/lancedb/lance/issues/5174
#[case::decimal128(DataType::Decimal128(38, 10))]
#[case::decimal256(DataType::Decimal256(76, 20))]
async fn test_query_decimal(#[case] data_type: DataType) {
    // Random decimals of the parameterized precision/scale, ~10% nulls.
    let batch = gen_batch()
        .col("id", array::step::<Int32Type>())
        .col("value", array::rand_type(&data_type).with_random_nulls(0.1))
        .into_batch_rows(RowCount::from(60))
        .unwrap();

    DatasetTestCases::from_data(batch)
        .with_index_types(
            "value",
            // NOTE: BloomFilter not supported for decimals
            // NOTE(review): ZoneMap is also omitted here but uncommented —
            // confirm whether that is intentional or an oversight.
            [None, Some(IndexType::Bitmap), Some(IndexType::BTree)],
        )
        .run(|ds: Dataset, original: RecordBatch| async move {
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
            // Sign comparisons and null checks.
            test_filter(&original, &ds, "value > 0").await;
            test_filter(&original, &ds, "value < 0").await;
            test_filter(&original, &ds, "value is null").await;
            test_filter(&original, &ds, "value is not null").await;
        })
        .await
}
Loading