diff --git a/python/python/tests/test_filter.py b/python/python/tests/test_filter.py index e9096599c5c..5ca6e645e49 100644 --- a/python/python/tests/test_filter.py +++ b/python/python/tests/test_filter.py @@ -81,6 +81,7 @@ def test_sql_predicates(dataset): ("int >= 50", 50), ("int = 50", 1), ("int != 50", 99), + ("int BETWEEN 50 AND 60", 11), ("float < 30.0", 45), ("str = 'aa'", 16), ("str in ('aa', 'bb')", 26), diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index 3777c90d489..dd3df96d641 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -86,6 +86,25 @@ def test_indexed_scalar_scan(indexed_dataset: lance.LanceDataset, data_table: pa assert actual_price == expected_price +def test_indexed_between(tmp_path): + dataset = lance.write_dataset(pa.table({"val": range(100)}), tmp_path) + dataset.create_scalar_index("val", index_type="BTREE") + + scanner = dataset.scanner(filter="val BETWEEN 10 AND 20", prefilter=True) + + assert "MaterializeIndex" in scanner.explain_plan() + + actual_data = scanner.to_table() + assert actual_data.num_rows == 11 + + scanner = dataset.scanner(filter="val >= 10 AND val <= 20", prefilter=True) + + assert "MaterializeIndex" in scanner.explain_plan() + + actual_data = scanner.to_table() + assert actual_data.num_rows == 11 + + def test_temporal_index(tmp_path): # Timestamps now = datetime.now() diff --git a/rust/lance-datafusion/src/planner.rs b/rust/lance-datafusion/src/planner.rs index a8d985d82a8..e9237f1aa2e 100644 --- a/rust/lance-datafusion/src/planner.rs +++ b/rust/lance-datafusion/src/planner.rs @@ -40,7 +40,7 @@ use datafusion::sql::sqlparser::ast::{ }; use datafusion::{ common::Column, - logical_expr::{col, BinaryExpr, Like, Operator}, + logical_expr::{col, Between, BinaryExpr, Like, Operator}, physical_expr::execution_props::ExecutionProps, physical_plan::PhysicalExpr, prelude::Expr, @@ -746,6 +746,25 @@ impl Planner { let field_access_expr = RawFieldAccessExpr { expr, field_access }; self.plan_field_access(field_access_expr) } + SQLExpr::Between { + expr, + negated, + low, + high, + } => { + // Parse the main expression and bounds + let expr = self.parse_sql_expr(expr)?; + let low = self.parse_sql_expr(low)?; + let high = self.parse_sql_expr(high)?; + + let between = Expr::Between(Between::new( + Box::new(expr), + *negated, + Box::new(low), + Box::new(high), + )); + Ok(between) + } _ => Err(Error::invalid_input( format!("Expression '{expr}' is not supported SQL in lance"), location!(), @@ -1463,6 +1482,98 @@ mod tests { } } + #[test] + fn test_sql_between() { + use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray}; + use arrow_schema::{DataType, Field, Schema, TimeUnit}; + use std::sync::Arc; + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Float64, false), + Field::new( + "ts", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + ])); + + let planner = Planner::new(schema.clone()); + + // Test integer BETWEEN + let expr = planner + .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)") + .unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + // Create timestamp array with values representing: + // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds) + let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00 + let ts_array = TimestampMicrosecondArray::from_iter_values( + (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart + ); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef, + Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))), + Arc::new(ts_array), + ], + ) + .unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + false, false, false, true, true, true, true, true, false, false + ]) + ); + + // Test NOT BETWEEN + let expr = planner + .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)") + .unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + true, true, true, false, false, false, false, false, true, true + ]) + ); + + // Test floating point BETWEEN + let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + false, false, false, true, true, true, true, false, false, false + ]) + ); + + // Test timestamp BETWEEN + let expr = planner + .parse_filter( + "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'", + ) + .unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + false, false, false, true, true, true, true, true, false, false + ]) + ); + } + #[test] fn test_sql_comparison() { // Create a batch with all data types diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index 24bbbd7cc0d..3aa05580032 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -16,6 +16,7 @@ use datafusion_expr::{ use futures::join; use lance_core::{utils::mask::RowIdMask, Result}; use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner}; +use log::warn; use tracing::instrument; use super::{AnyQuery, LabelListQuery, SargableQuery, ScalarIndex}; @@ -564,9 +565,67 @@ fn visit_comparison( let scalar = maybe_scalar(&expr.right, col_type)?; query_parser.visit_comparison(column, scalar, &expr.op) } else { - let (column, col_type, query_parser) = maybe_indexed_column(&expr.right, index_info)?; - let scalar = maybe_scalar(&expr.left, col_type)?; - query_parser.visit_comparison(column, scalar, &expr.op) + // Datafusion's query simplifier will canonicalize expressions and so we shouldn't reach this case. If, for some reason, we + // do reach this case we can handle it in the future by inverting expr.op and swapping the left and right sides + warn!("Unexpected comparison encountered (DF simplifier should have removed this case). Scalar indices will not be applied"); + None + } +} + +fn maybe_between(expr: &BinaryExpr) -> Option { + let left_comparison = match expr.left.as_ref() { + Expr::BinaryExpr(binary_expr) => Some(binary_expr), + _ => None, + }?; + let right_comparison = match expr.right.as_ref() { + Expr::BinaryExpr(binary_expr) => Some(binary_expr), + _ => None, + }?; + + match (left_comparison.op, right_comparison.op) { + (Operator::GtEq, Operator::LtEq) => { + // We have x >= y && a <= b. + // If x == a then it is a between query + // if y == b then it is a between query + if left_comparison.left == right_comparison.left { + Some(Between { + expr: left_comparison.left.clone(), + low: left_comparison.right.clone(), + high: right_comparison.right.clone(), + negated: false, + }) + } else if left_comparison.right == right_comparison.right { + Some(Between { + expr: left_comparison.right.clone(), + low: right_comparison.left.clone(), + high: left_comparison.left.clone(), + negated: false, + }) + } else { + None + } + } + (Operator::LtEq, Operator::GtEq) => { + // Same logic as above we just switch the low/high + if left_comparison.left == right_comparison.left { + Some(Between { + expr: left_comparison.left.clone(), + low: right_comparison.right.clone(), + high: left_comparison.right.clone(), + negated: false, + }) + } else if left_comparison.right == right_comparison.right { + Some(Between { + expr: left_comparison.right.clone(), + low: left_comparison.left.clone(), + high: right_comparison.left.clone(), + negated: false, + }) + } else { + None + } + } + _ => None, } } @@ -574,6 +633,17 @@ fn visit_and( expr: &BinaryExpr, index_info: &dyn IndexInformationProvider, ) -> Option { + // Many scalar indices can efficiently handle a BETWEEN query as a single search and this + // can be much more efficient than two separate range queries. As an optimization we check + // to see if this is a between query and, if so, we handle it as a single query + // + // Note: We can't rely on users writing the SQL BETWEEN operator because: + // * Some users won't realize it's an option or a good idea + // * Datafusion's simplifier will rewrite the BETWEEN operator into two separate range queries + if let Some(between) = maybe_between(expr) { + return visit_between(&between, index_info); + } + let left = visit_node(&expr.left, index_info); let right = visit_node(&expr.right, index_info); match (left, right) { @@ -912,6 +982,7 @@ mod tests { ]); check_no_index(&index_info, "size BETWEEN 5 AND 10"); + // 5 different ways of writing BETWEEN (all should be recognized) check_simple( &index_info, "aisle BETWEEN 5 AND 10", @@ -921,6 +992,45 @@ mod tests { Bound::Included(ScalarValue::UInt32(Some(10))), ), ); + check_simple( + &index_info, + "aisle >= 5 AND aisle <= 10", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + + check_simple( + &index_info, + "aisle <= 10 AND aisle >= 5", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + + check_simple( + &index_info, + "5 <= aisle AND 10 >= aisle", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + + check_simple( + &index_info, + "10 >= aisle AND 5 <= aisle", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); check_simple( &index_info, "on_sale IS TRUE", @@ -1023,6 +1133,10 @@ mod tests { Bound::Unbounded, ), ); + // In the future we can handle this case if we need to. For + // now let's make sure we don't accidentally do the wrong thing + // (we were getting this backwards in the past) + check_no_index(&index_info, "10 > aisle"); check_simple( &index_info, "aisle >= 10",