From 9553e64538907eef32d71d46defaba2507b757e2 Mon Sep 17 00:00:00 2001 From: Connell Gough Date: Mon, 9 Dec 2024 20:30:52 -0600 Subject: [PATCH 1/4] feat: support between sql clauses --- rust/lance-datafusion/src/planner.rs | 106 ++++++++++++++++++++++++++- 1 file changed, 105 insertions(+), 1 deletion(-) diff --git a/rust/lance-datafusion/src/planner.rs b/rust/lance-datafusion/src/planner.rs index a8d985d82a8..f7a2d1e1786 100644 --- a/rust/lance-datafusion/src/planner.rs +++ b/rust/lance-datafusion/src/planner.rs @@ -40,7 +40,7 @@ use datafusion::sql::sqlparser::ast::{ }; use datafusion::{ common::Column, - logical_expr::{col, BinaryExpr, Like, Operator}, + logical_expr::{col, BinaryExpr, Like, Operator, Between}, physical_expr::execution_props::ExecutionProps, physical_plan::PhysicalExpr, prelude::Expr, @@ -746,6 +746,25 @@ impl Planner { let field_access_expr = RawFieldAccessExpr { expr, field_access }; self.plan_field_access(field_access_expr) } + SQLExpr::Between { + expr, + negated, + low, + high, + } => { + // Parse the main expression and bounds + let expr = self.parse_sql_expr(expr)?; + let low = self.parse_sql_expr(low)?; + let high = self.parse_sql_expr(high)?; + + let foo = Expr::Between(Between::new( + Box::new(expr), + *negated, + Box::new(low), + Box::new(high), + )); + Ok(foo) + } _ => Err(Error::invalid_input( format!("Expression '{expr}' is not supported SQL in lance"), location!(), @@ -1463,6 +1482,91 @@ mod tests { } } + + #[test] +fn test_sql_between() { + use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray}; + use arrow_schema::{DataType, Field, Schema, TimeUnit}; + use std::sync::Arc; + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Float64, false), + Field::new("ts", DataType::Timestamp(TimeUnit::Microsecond, None), false), + ])); + + let planner = Planner::new(schema.clone()); + + // Test integer BETWEEN + let expr = planner.parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)").unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + // Create timestamp array with values representing: + // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds) + let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00 + let ts_array = TimestampMicrosecondArray::from_iter_values( + (0..10).map(|i| base_ts + i * 1_000_000) // Each value is 1 second apart + ); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef, + Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))), + Arc::new(ts_array), + ], + ) + .unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + false, false, false, true, true, true, true, true, false, false + ]) + ); + + // Test NOT BETWEEN + let expr = planner.parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)").unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + true, true, true, false, false, false, false, false, true, true + ]) + ); + + // Test floating point BETWEEN + let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + false, false, false, true, true, true, true, false, false, false + ]) + ); + + // Test timestamp BETWEEN + let expr = planner + .parse_filter( + "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'" + ) + .unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + false, false, false, true, true, true, true, true, false, false + ]) + ); +} + #[test] fn test_sql_comparison() { // Create a batch with all data types From 46bb5ea6d96387687fa793adbc5226d112d1d4c9 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 10 Dec 2024 06:30:56 -0800 Subject: [PATCH 2/4] Added optimization in the expr->scalar index stage to detect anything that looks like a between query and use the between correctly. Minor formatting changes. --- python/python/tests/test_filter.py | 1 + python/python/tests/test_scalar_index.py | 19 +++ rust/lance-datafusion/src/planner.rs | 177 +++++++++++----------- rust/lance-index/src/scalar/expression.rs | 120 ++++++++++++++- 4 files changed, 229 insertions(+), 88 deletions(-) diff --git a/python/python/tests/test_filter.py b/python/python/tests/test_filter.py index e9096599c5c..5ca6e645e49 100644 --- a/python/python/tests/test_filter.py +++ b/python/python/tests/test_filter.py @@ -81,6 +81,7 @@ def test_sql_predicates(dataset): ("int >= 50", 50), ("int = 50", 1), ("int != 50", 99), + ("int BETWEEN 50 AND 60", 11), ("float < 30.0", 45), ("str = 'aa'", 16), ("str in ('aa', 'bb')", 26), diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index 3777c90d489..dd3df96d641 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -86,6 +86,25 @@ def test_indexed_scalar_scan(indexed_dataset: lance.LanceDataset, data_table: pa assert actual_price == expected_price +def test_indexed_between(tmp_path): + dataset = lance.write_dataset(pa.table({"val": range(100)}), tmp_path) + dataset.create_scalar_index("val", index_type="BTREE") + + scanner = dataset.scanner(filter="val BETWEEN 10 AND 20", prefilter=True) + + assert "MaterializeIndex" in scanner.explain_plan() + + actual_data = scanner.to_table() + assert actual_data.num_rows == 11 + + scanner = dataset.scanner(filter="val >= 10 AND val <= 20", prefilter=True) + + assert "MaterializeIndex" in scanner.explain_plan() + + actual_data = scanner.to_table() + assert actual_data.num_rows == 11 + + def test_temporal_index(tmp_path): # Timestamps now = datetime.now() diff --git a/rust/lance-datafusion/src/planner.rs b/rust/lance-datafusion/src/planner.rs index f7a2d1e1786..2d17c290dab 100644 --- a/rust/lance-datafusion/src/planner.rs +++ b/rust/lance-datafusion/src/planner.rs @@ -40,7 +40,7 @@ use datafusion::sql::sqlparser::ast::{ }; use datafusion::{ common::Column, - logical_expr::{col, BinaryExpr, Like, Operator, Between}, + logical_expr::{col, Between, BinaryExpr, Like, Operator}, physical_expr::execution_props::ExecutionProps, physical_plan::PhysicalExpr, prelude::Expr, @@ -758,10 +758,10 @@ impl Planner { let high = self.parse_sql_expr(high)?; let foo = Expr::Between(Between::new( - Box::new(expr), + Box::new(expr), *negated, - Box::new(low), - Box::new(high), + Box::new(low), + Box::new(high), )); Ok(foo) } @@ -1482,90 +1482,97 @@ mod tests { } } - #[test] -fn test_sql_between() { - use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray}; - use arrow_schema::{DataType, Field, Schema, TimeUnit}; - use std::sync::Arc; - - let schema = Arc::new(Schema::new(vec![ - Field::new("x", DataType::Int32, false), - Field::new("y", DataType::Float64, false), - Field::new("ts", DataType::Timestamp(TimeUnit::Microsecond, None), false), - ])); - - let planner = Planner::new(schema.clone()); - - // Test integer BETWEEN - let expr = planner.parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)").unwrap(); - let physical_expr = planner.create_physical_expr(&expr).unwrap(); - - // Create timestamp array with values representing: - // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds) - let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00 - let ts_array = TimestampMicrosecondArray::from_iter_values( - (0..10).map(|i| base_ts + i * 1_000_000) // Each value is 1 second apart - ); - - let batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef, - Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))), - Arc::new(ts_array), - ], - ) - .unwrap(); - - let predicates = physical_expr.evaluate(&batch).unwrap(); - assert_eq!( - predicates.into_array(0).unwrap().as_ref(), - &BooleanArray::from(vec![ - false, false, false, true, true, true, true, true, false, false - ]) - ); - - // Test NOT BETWEEN - let expr = planner.parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)").unwrap(); - let physical_expr = planner.create_physical_expr(&expr).unwrap(); - - let predicates = physical_expr.evaluate(&batch).unwrap(); - assert_eq!( - predicates.into_array(0).unwrap().as_ref(), - &BooleanArray::from(vec![ - true, true, true, false, false, false, false, false, true, true - ]) - ); - - // Test floating point BETWEEN - let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap(); - let physical_expr = planner.create_physical_expr(&expr).unwrap(); - - let predicates = physical_expr.evaluate(&batch).unwrap(); - assert_eq!( - predicates.into_array(0).unwrap().as_ref(), - &BooleanArray::from(vec![ - false, false, false, true, true, true, true, false, false, false - ]) - ); - - // Test timestamp BETWEEN - let expr = planner - .parse_filter( - "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'" + fn test_sql_between() { + use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray}; + use arrow_schema::{DataType, Field, Schema, TimeUnit}; + use std::sync::Arc; + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Float64, false), + Field::new( + "ts", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + ])); + + let planner = Planner::new(schema.clone()); + + // Test integer BETWEEN + let expr = planner + .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)") + .unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + // Create timestamp array with values representing: + // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds) + let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00 + let ts_array = TimestampMicrosecondArray::from_iter_values( + (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart + ); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef, + Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))), + Arc::new(ts_array), + ], ) .unwrap(); - let physical_expr = planner.create_physical_expr(&expr).unwrap(); - - let predicates = physical_expr.evaluate(&batch).unwrap(); - assert_eq!( - predicates.into_array(0).unwrap().as_ref(), - &BooleanArray::from(vec![ - false, false, false, true, true, true, true, true, false, false - ]) - ); -} + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + false, false, false, true, true, true, true, true, false, false + ]) + ); + + // Test NOT BETWEEN + let expr = planner + .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)") + .unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + true, true, true, false, false, false, false, false, true, true + ]) + ); + + // Test floating point BETWEEN + let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + false, false, false, true, true, true, true, false, false, false + ]) + ); + + // Test timestamp BETWEEN + let expr = planner + .parse_filter( + "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'", + ) + .unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + false, false, false, true, true, true, true, true, false, false + ]) + ); + } #[test] fn test_sql_comparison() { diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index 24bbbd7cc0d..3aa05580032 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -16,6 +16,7 @@ use datafusion_expr::{ use futures::join; use lance_core::{utils::mask::RowIdMask, Result}; use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner}; +use log::warn; use tracing::instrument; use super::{AnyQuery, LabelListQuery, SargableQuery, ScalarIndex}; @@ -564,9 +565,67 @@ fn visit_comparison( let scalar = maybe_scalar(&expr.right, col_type)?; query_parser.visit_comparison(column, scalar, &expr.op) } else { - let (column, col_type, query_parser) = maybe_indexed_column(&expr.right, index_info)?; - let scalar = maybe_scalar(&expr.left, col_type)?; - query_parser.visit_comparison(column, scalar, &expr.op) + // Datafusion's query simplifier will canonicalize expressions and so we shouldn't reach this case. If, for some reason, we + // do reach this case we can handle it in the future by inverting expr.op and swapping the left and right sides + warn!("Unexpected comparison encountered (DF simplifier should have removed this case). Scalar indices will not be applied"); + None + } +} + +fn maybe_between(expr: &BinaryExpr) -> Option { + let left_comparison = match expr.left.as_ref() { + Expr::BinaryExpr(binary_expr) => Some(binary_expr), + _ => None, + }?; + let right_comparison = match expr.right.as_ref() { + Expr::BinaryExpr(binary_expr) => Some(binary_expr), + _ => None, + }?; + + match (left_comparison.op, right_comparison.op) { + (Operator::GtEq, Operator::LtEq) => { + // We have x >= y && a <= b. + // If x == a then it is a between query + // if y == b then it is a between query + if left_comparison.left == right_comparison.left { + Some(Between { + expr: left_comparison.left.clone(), + low: left_comparison.right.clone(), + high: right_comparison.right.clone(), + negated: false, + }) + } else if left_comparison.right == right_comparison.right { + Some(Between { + expr: left_comparison.right.clone(), + low: right_comparison.left.clone(), + high: left_comparison.left.clone(), + negated: false, + }) + } else { + None + } + } + (Operator::LtEq, Operator::GtEq) => { + // Same logic as above we just switch the low/high + if left_comparison.left == right_comparison.left { + Some(Between { + expr: left_comparison.left.clone(), + low: right_comparison.right.clone(), + high: left_comparison.right.clone(), + negated: false, + }) + } else if left_comparison.right == right_comparison.right { + Some(Between { + expr: left_comparison.right.clone(), + low: left_comparison.left.clone(), + high: right_comparison.left.clone(), + negated: false, + }) + } else { + None + } + } + _ => None, } } @@ -574,6 +633,17 @@ fn visit_and( expr: &BinaryExpr, index_info: &dyn IndexInformationProvider, ) -> Option { + // Many scalar indices can efficiently handle a BETWEEN query as a single search and this + // can be much more efficient than two separate range queries. As an optimization we check + // to see if this is a between query and, if so, we handle it as a single query + // + // Note: We can't rely on users writing the SQL BETWEEN operator because: + // * Some users won't realize it's an option or a good idea + // * Datafusion's simplifier will rewrite the BETWEEN operator into two separate range queries + if let Some(between) = maybe_between(expr) { + return visit_between(&between, index_info); + } + let left = visit_node(&expr.left, index_info); let right = visit_node(&expr.right, index_info); match (left, right) { @@ -912,6 +982,7 @@ mod tests { ]); check_no_index(&index_info, "size BETWEEN 5 AND 10"); + // 5 different ways of writing BETWEEN (all should be recognized) check_simple( &index_info, "aisle BETWEEN 5 AND 10", @@ -921,6 +992,45 @@ mod tests { Bound::Included(ScalarValue::UInt32(Some(10))), ), ); + check_simple( + &index_info, + "aisle >= 5 AND aisle <= 10", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + + check_simple( + &index_info, + "aisle <= 10 AND aisle >= 5", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + + check_simple( + &index_info, + "5 <= aisle AND 10 >= aisle", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + + check_simple( + &index_info, + "10 >= aisle AND 5 <= aisle", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); check_simple( &index_info, "on_sale IS TRUE", @@ -1023,6 +1133,10 @@ mod tests { Bound::Unbounded, ), ); + // In the future we can handle this case if we need to. For + // now let's make sure we don't accidentally do the wrong thing + // (we were getting this backwards in the past) + check_no_index(&index_info, "10 > aisle"); check_simple( &index_info, "aisle >= 10", From 7618f8da625073478903c2dfde99aa629ad8005c Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 10 Dec 2024 07:26:23 -0800 Subject: [PATCH 3/4] Address clippy suggestion --- rust/lance-datafusion/src/planner.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/lance-datafusion/src/planner.rs b/rust/lance-datafusion/src/planner.rs index 2d17c290dab..677f625b563 100644 --- a/rust/lance-datafusion/src/planner.rs +++ b/rust/lance-datafusion/src/planner.rs @@ -757,13 +757,13 @@ impl Planner { let low = self.parse_sql_expr(low)?; let high = self.parse_sql_expr(high)?; - let foo = Expr::Between(Between::new( + let between = Expr::Between(Between::new( Box::new(expr), *negated, Box::new(low), Box::new(high), )); - Ok(foo) + Ok(between) } _ => Err(Error::invalid_input( format!("Expression '{expr}' is not supported SQL in lance"), From 725f9fada42a4cf9d41d8d69a3c94e7de3424ad4 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 10 Dec 2024 09:17:06 -0800 Subject: [PATCH 4/4] Addressing clippy suggestion --- rust/lance-datafusion/src/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/lance-datafusion/src/planner.rs b/rust/lance-datafusion/src/planner.rs index 677f625b563..e9237f1aa2e 100644 --- a/rust/lance-datafusion/src/planner.rs +++ b/rust/lance-datafusion/src/planner.rs @@ -1514,7 +1514,7 @@ mod tests { ); let batch = RecordBatch::try_new( - schema.clone(), + schema, vec![ Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef, Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))),