Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/python/tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def test_sql_predicates(dataset):
("int >= 50", 50),
("int = 50", 1),
("int != 50", 99),
("int BETWEEN 50 AND 60", 11),
("float < 30.0", 45),
("str = 'aa'", 16),
("str in ('aa', 'bb')", 26),
Expand Down
19 changes: 19 additions & 0 deletions python/python/tests/test_scalar_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,25 @@ def test_indexed_scalar_scan(indexed_dataset: lance.LanceDataset, data_table: pa
assert actual_price == expected_price


def test_indexed_between(tmp_path):
    """A BTREE-indexed column is scanned via the index for inclusive range filters.

    Writes a single ``val`` column holding ``range(100)``, builds a BTREE scalar
    index on it, then checks that both the SQL ``BETWEEN`` spelling and the
    equivalent ``>= ... AND <= ...`` spelling produce a plan that mentions
    ``MaterializeIndex`` and return the 11 rows with values 10 through 20.
    """
    ds = lance.write_dataset(pa.table({"val": range(100)}), tmp_path)
    ds.create_scalar_index("val", index_type="BTREE")

    # Both spellings describe the same inclusive range and must behave alike.
    for predicate in ("val BETWEEN 10 AND 20", "val >= 10 AND val <= 20"):
        scanner = ds.scanner(filter=predicate, prefilter=True)

        assert "MaterializeIndex" in scanner.explain_plan()

        result = scanner.to_table()
        assert result.num_rows == 11


def test_temporal_index(tmp_path):
# Timestamps
now = datetime.now()
Expand Down
113 changes: 112 additions & 1 deletion rust/lance-datafusion/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use datafusion::sql::sqlparser::ast::{
};
use datafusion::{
common::Column,
logical_expr::{col, BinaryExpr, Like, Operator},
logical_expr::{col, Between, BinaryExpr, Like, Operator},
physical_expr::execution_props::ExecutionProps,
physical_plan::PhysicalExpr,
prelude::Expr,
Expand Down Expand Up @@ -746,6 +746,25 @@ impl Planner {
let field_access_expr = RawFieldAccessExpr { expr, field_access };
self.plan_field_access(field_access_expr)
}
SQLExpr::Between {
expr,
negated,
low,
high,
} => {
// Parse the main expression and bounds
let expr = self.parse_sql_expr(expr)?;
let low = self.parse_sql_expr(low)?;
let high = self.parse_sql_expr(high)?;

let between = Expr::Between(Between::new(
Box::new(expr),
*negated,
Box::new(low),
Box::new(high),
));
Ok(between)
}
_ => Err(Error::invalid_input(
format!("Expression '{expr}' is not supported SQL in lance"),
location!(),
Expand Down Expand Up @@ -1463,6 +1482,98 @@ mod tests {
}
}

#[test]
fn test_sql_between() {
    use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray};
    use arrow_schema::{DataType, Field, Schema, TimeUnit};
    use std::sync::Arc;

    // One column per data type that BETWEEN is exercised against below.
    let schema = Arc::new(Schema::new(vec![
        Field::new("x", DataType::Int32, false),
        Field::new("y", DataType::Float64, false),
        Field::new(
            "ts",
            DataType::Timestamp(TimeUnit::Microsecond, None),
            false,
        ),
    ]));

    let planner = Planner::new(schema.clone());

    // Ten timestamps, one second apart (expressed in microseconds), covering
    // 2024-01-01 00:00:00 through 2024-01-01 00:00:09.
    let first_ts_micros = 1704067200000000_i64; // 2024-01-01 00:00:00
    let timestamps = TimestampMicrosecondArray::from_iter_values(
        (0..10).map(|i| first_ts_micros + i * 1_000_000),
    );

    // Ten rows: x = 0..10, y = 0.0..10.0, ts as built above.
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef,
            Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),
            Arc::new(timestamps),
        ],
    )
    .unwrap();

    // Parse the filter, plan it into a physical expression and evaluate it
    // against `batch`, yielding the per-row predicate results as an array.
    let evaluate = |filter: &str| {
        let expr = planner.parse_filter(filter).unwrap();
        let physical_expr = planner.create_physical_expr(&expr).unwrap();
        physical_expr
            .evaluate(&batch)
            .unwrap()
            .into_array(0)
            .unwrap()
    };

    // Integer BETWEEN: both bounds are inclusive, so rows 3..=7 match.
    assert_eq!(
        evaluate("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)").as_ref(),
        &BooleanArray::from(vec![
            false, false, false, true, true, true, true, true, false, false
        ])
    );

    // NOT BETWEEN: the exact complement of the case above.
    assert_eq!(
        evaluate("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)").as_ref(),
        &BooleanArray::from(vec![
            true, true, true, false, false, false, false, false, true, true
        ])
    );

    // Floating point BETWEEN: 2.5..=6.5 keeps the values 3.0 through 6.0.
    assert_eq!(
        evaluate("y BETWEEN 2.5 AND 6.5").as_ref(),
        &BooleanArray::from(vec![
            false, false, false, true, true, true, true, false, false, false
        ])
    );

    // Timestamp BETWEEN: seconds :03 through :07 inclusive.
    assert_eq!(
        evaluate(
            "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'",
        )
        .as_ref(),
        &BooleanArray::from(vec![
            false, false, false, true, true, true, true, true, false, false
        ])
    );
}

#[test]
fn test_sql_comparison() {
// Create a batch with all data types
Expand Down
120 changes: 117 additions & 3 deletions rust/lance-index/src/scalar/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use datafusion_expr::{
use futures::join;
use lance_core::{utils::mask::RowIdMask, Result};
use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner};
use log::warn;
use tracing::instrument;

use super::{AnyQuery, LabelListQuery, SargableQuery, ScalarIndex};
Expand Down Expand Up @@ -564,16 +565,85 @@ fn visit_comparison(
let scalar = maybe_scalar(&expr.right, col_type)?;
query_parser.visit_comparison(column, scalar, &expr.op)
} else {
let (column, col_type, query_parser) = maybe_indexed_column(&expr.right, index_info)?;
let scalar = maybe_scalar(&expr.left, col_type)?;
query_parser.visit_comparison(column, scalar, &expr.op)
// Datafusion's query simplifier will canonicalize expressions and so we shouldn't reach this case. If, for some reason, we
// do reach this case we can handle it in the future by inverting expr.op and swapping the left and right sides
warn!("Unexpected comparison encountered (DF simplifier should have removed this case). Scalar indices will not be applied");
None
}
}

/// Tries to interpret a binary expression whose two operands are a `>=`
/// comparison and a `<=` comparison (in either order) as a single `Between`.
///
/// The two comparisons must share an operand in the same position: either
/// both left-hand sides are equal (`x >= lo` / `x <= hi`) or both right-hand
/// sides are equal (`hi >= x` / `lo <= x`).  The shared operand becomes the
/// `expr` of the returned `Between`; the remaining operands become `low` and
/// `high`.  The result is never negated.
///
/// Returns `None` when either operand is not a binary expression, when the
/// operators are not exactly one `GtEq` and one `LtEq`, or when no operand is
/// shared.  The operator of `expr` itself is not inspected here.
fn maybe_between(expr: &BinaryExpr) -> Option<Between> {
    let lhs = match expr.left.as_ref() {
        Expr::BinaryExpr(inner) => inner,
        _ => return None,
    };
    let rhs = match expr.right.as_ref() {
        Expr::BinaryExpr(inner) => inner,
        _ => return None,
    };

    // Work out which side is the `>=` comparison and which is the `<=`
    // comparison so both orderings can share the logic below.
    let (ge_cmp, le_cmp) = match (lhs.op, rhs.op) {
        (Operator::GtEq, Operator::LtEq) => (lhs, rhs),
        (Operator::LtEq, Operator::GtEq) => (rhs, lhs),
        _ => return None,
    };

    if lhs.left == rhs.left {
        // Shared operand on the left: `x >= low` together with `x <= high`.
        Some(Between {
            expr: lhs.left.clone(),
            negated: false,
            low: ge_cmp.right.clone(),
            high: le_cmp.right.clone(),
        })
    } else if lhs.right == rhs.right {
        // Shared operand on the right: `high >= x` together with `low <= x`.
        Some(Between {
            expr: lhs.right.clone(),
            negated: false,
            low: le_cmp.left.clone(),
            high: ge_cmp.left.clone(),
        })
    } else {
        None
    }
}

fn visit_and(
expr: &BinaryExpr,
index_info: &dyn IndexInformationProvider,
) -> Option<IndexedExpression> {
// Many scalar indices can efficiently handle a BETWEEN query as a single search and this
// can be much more efficient than two separate range queries. As an optimization we check
// to see if this is a between query and, if so, we handle it as a single query
//
// Note: We can't rely on users writing the SQL BETWEEN operator because:
// * Some users won't realize it's an option or a good idea
// * Datafusion's simplifier will rewrite the BETWEEN operator into two separate range queries
if let Some(between) = maybe_between(expr) {
return visit_between(&between, index_info);
}

let left = visit_node(&expr.left, index_info);
let right = visit_node(&expr.right, index_info);
match (left, right) {
Expand Down Expand Up @@ -912,6 +982,7 @@ mod tests {
]);

check_no_index(&index_info, "size BETWEEN 5 AND 10");
// 5 different ways of writing BETWEEN (all should be recognized)
check_simple(
&index_info,
"aisle BETWEEN 5 AND 10",
Expand All @@ -921,6 +992,45 @@ mod tests {
Bound::Included(ScalarValue::UInt32(Some(10))),
),
);
check_simple(
&index_info,
"aisle >= 5 AND aisle <= 10",
"aisle",
SargableQuery::Range(
Bound::Included(ScalarValue::UInt32(Some(5))),
Bound::Included(ScalarValue::UInt32(Some(10))),
),
);

check_simple(
&index_info,
"aisle <= 10 AND aisle >= 5",
"aisle",
SargableQuery::Range(
Bound::Included(ScalarValue::UInt32(Some(5))),
Bound::Included(ScalarValue::UInt32(Some(10))),
),
);

check_simple(
&index_info,
"5 <= aisle AND 10 >= aisle",
"aisle",
SargableQuery::Range(
Bound::Included(ScalarValue::UInt32(Some(5))),
Bound::Included(ScalarValue::UInt32(Some(10))),
),
);

check_simple(
&index_info,
"10 >= aisle AND 5 <= aisle",
"aisle",
SargableQuery::Range(
Bound::Included(ScalarValue::UInt32(Some(5))),
Bound::Included(ScalarValue::UInt32(Some(10))),
),
);
check_simple(
&index_info,
"on_sale IS TRUE",
Expand Down Expand Up @@ -1023,6 +1133,10 @@ mod tests {
Bound::Unbounded,
),
);
// In the future we can handle this case if we need to. For
// now let's make sure we don't accidentally do the wrong thing
// (we were getting this backwards in the past)
check_no_index(&index_info, "10 > aisle");
check_simple(
&index_info,
"aisle >= 10",
Expand Down