From 96cd76d9bbd9eeecc44c2cf552347ce3158c2854 Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Fri, 8 Jan 2021 07:30:01 +1100 Subject: [PATCH] in_list --- rust/benchmarks/src/bin/tpch.rs | 2 +- rust/datafusion/src/logical_plan/expr.rs | 53 +++ rust/datafusion/src/logical_plan/mod.rs | 4 +- rust/datafusion/src/optimizer/utils.rs | 15 + .../src/physical_plan/expressions.rs | 397 +++++++++++++++++- rust/datafusion/src/physical_plan/planner.rs | 102 +++++ rust/datafusion/src/prelude.rs | 4 +- rust/datafusion/src/sql/planner.rs | 17 + rust/datafusion/src/sql/utils.rs | 26 ++ rust/datafusion/tests/sql.rs | 86 ++++ 10 files changed, 700 insertions(+), 6 deletions(-) diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs index c7a2dce677f..539b8d23d08 100644 --- a/rust/benchmarks/src/bin/tpch.rs +++ b/rust/benchmarks/src/bin/tpch.rs @@ -656,7 +656,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result= date '1994-01-01' diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 66658e1bab3..f8e364a01f6 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -158,6 +158,15 @@ pub enum Expr { /// List of expressions to feed to the functions as arguments args: Vec, }, + /// Returns whether the list contains the expr value. + InList { + /// The expression to compare + expr: Box, + /// A list of values to compare against + list: Vec, + /// Whether the expression is negated + negated: bool, + }, /// Represents a reference to all fields in a schema. Wildcard, } @@ -224,6 +233,7 @@ impl Expr { ), Expr::Sort { ref expr, .. } => expr.get_type(schema), Expr::Between { .. } => Ok(DataType::Boolean), + Expr::InList { .. } => Ok(DataType::Boolean), Expr::Wildcard => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), @@ -278,6 +288,7 @@ impl Expr { } => Ok(left.nullable(input_schema)? 
|| right.nullable(input_schema)?), Expr::Sort { ref expr, .. } => expr.nullable(input_schema), Expr::Between { ref expr, .. } => expr.nullable(input_schema), + Expr::InList { ref expr, .. } => expr.nullable(input_schema), Expr::Wildcard => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), @@ -389,6 +400,15 @@ impl Expr { Expr::Alias(Box::new(self.clone()), name.to_owned()) } + /// Return `self IN (list)`, or `self NOT IN (list)` when `negated` is true + pub fn in_list(&self, list: Vec, negated: bool) -> Expr { + Expr::InList { + expr: Box::new(self.clone()), + list, + negated, + } + } + /// Create a sort expression from an existing expression. /// /// ``` @@ -579,6 +599,15 @@ pub fn count_distinct(expr: Expr) -> Expr { } } +/// Create an in_list expression +pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { + Expr::InList { + expr: Box::new(expr), + list, + negated, + } +} + /// Whether it can be represented as a literal expression pub trait Literal { /// convert the value to a Literal expression @@ -814,6 +843,17 @@ impl fmt::Debug for Expr { write!(f, "{:?} BETWEEN {:?} AND {:?}", expr, low, high) } } + Expr::InList { + expr, + list, + negated, + } => { + if *negated { + write!(f, "{:?} NOT IN ({:?})", expr, list) + } else { + write!(f, "{:?} IN ({:?})", expr, list) + } + } Expr::Wildcard => write!(f, "*"), } } @@ -906,6 +946,25 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { } Ok(format!("{}({})", fun.name, names.join(","))) } + Expr::InList { + expr, + list, + negated, + } => { + let expr = create_name(expr, input_schema)?; + // Name every list item up front; formatting the lazy `map` adapter itself + // printed the iterator's Debug form and never surfaced naming errors. + let list = list + .iter() + .map(|item| create_name(item, input_schema)) + .collect::<Result<Vec<String>>>()? + .join(", "); + if *negated { + Ok(format!("{} NOT IN ({})", expr, list)) + } else { + Ok(format!("{} IN ({})", expr, list)) + } + } other => Err(DataFusionError::NotImplemented(format!( "Physical plan does not support logical expression {:?}", other diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 
9be3b9449e2..24c493bda8f 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -36,8 +36,8 @@ pub use display::display_schema; pub use expr::{ abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, concat, cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, - length, lit, ln, log10, log2, lower, ltrim, max, min, or, round, rtrim, signum, sin, - sqrt, sum, tan, trim, trunc, upper, when, Expr, Literal, + in_list, length, lit, ln, log10, log2, lower, ltrim, max, min, or, round, rtrim, + signum, sin, sqrt, sum, tan, trim, trunc, upper, when, Expr, Literal, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 2e950fd2be3..75661c6f723 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -102,6 +102,13 @@ pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet) -> Result< expr_to_column_names(high, accum)?; Ok(()) } + Expr::InList { expr, list, .. } => { + expr_to_column_names(expr, accum)?; + for list_expr in list { + expr_to_column_names(list_expr, accum)?; + } + Ok(()) + } Expr::Wildcard => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), @@ -305,6 +312,13 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { low.as_ref().to_owned(), high.as_ref().to_owned(), ]), + Expr::InList { expr, list, .. } => { + let mut expr_list: Vec = vec![expr.as_ref().to_owned()]; + for list_expr in list { + expr_list.push(list_expr.to_owned()); + } + Ok(expr_list) + } Expr::Wildcard { .. } => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), @@ -416,6 +430,17 @@ pub fn rewrite_expression(expr: &Expr, expressions: &Vec) -> Result Ok(expr) } } + Expr::InList { negated, .. } => match expressions.split_first() { + // `expr_sub_expressions` lists the tested value first, then the list items + Some((value, items)) => Ok(Expr::InList { + expr: Box::new(value.clone()), + list: items.to_vec(), + negated: *negated, + }), + None => Err(DataFusionError::Internal( + "InList rewrite requires at least the value expression".to_owned(), + )), + }, Expr::Wildcard { .. 
} => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs index 468e4a2d161..2b2bea43fb2 100644 --- a/rust/datafusion/src/physical_plan/expressions.rs +++ b/rust/datafusion/src/physical_plan/expressions.rs @@ -26,7 +26,10 @@ use crate::error::{DataFusionError, Result}; use crate::logical_plan::Operator; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; use crate::scalar::ScalarValue; -use arrow::array::{self, Array, BooleanBuilder, LargeStringArray}; +use arrow::array::{ + self, Array, BooleanBuilder, GenericStringArray, LargeStringArray, + StringOffsetSizeTrait, +}; use arrow::compute; use arrow::compute::kernels; use arrow::compute::kernels::arithmetic::{add, divide, multiply, negate, subtract}; @@ -2414,6 +2417,236 @@ impl PhysicalSortExpr { } } +/// InList +#[derive(Debug)] +pub struct InListExpr { + expr: Arc, + list: Vec>, + negated: bool, +} + +macro_rules! 
make_contains { + ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr, $SCALAR_VALUE:ident, $ARRAY_TYPE:ident) => {{ + let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + + let mut contains_null = false; + let values = $LIST_VALUES + .iter() + .flat_map(|expr| match expr { + ColumnarValue::Scalar(s) => match s { + ScalarValue::$SCALAR_VALUE(Some(v)) => Some(*v), + ScalarValue::$SCALAR_VALUE(None) => { + contains_null = true; + None + } + ScalarValue::Utf8(None) => { + contains_null = true; + None + } + datatype => unimplemented!("Unexpected type {} for InList", datatype), + }, + ColumnarValue::Array(_) => { + unimplemented!("InList does not yet support nested columns.") + } + }) + .collect::>(); + + Ok(ColumnarValue::Array(Arc::new( + array + .iter() + .map(|x| { + let contains = x.map(|x| values.contains(&x)); + match contains { + Some(true) => { + if $NEGATED { + Some(false) + } else { + Some(true) + } + } + Some(false) => { + if contains_null { + None + } else if $NEGATED { + Some(true) + } else { + Some(false) + } + } + None => None, + } + }) + .collect::(), + ))) + }}; +} + +impl InListExpr { + /// Create a new InList expression + pub fn new( + expr: Arc, + list: Vec>, + negated: bool, + ) -> Self { + Self { + expr, + list, + negated, + } + } + + /// Compare for specific utf8 types + fn compare_utf8( + &self, + array: ArrayRef, + list_values: Vec, + negated: bool, + ) -> Result { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let mut contains_null = false; + let values = list_values + .iter() + .flat_map(|expr| match expr { + ColumnarValue::Scalar(s) => match s { + ScalarValue::Utf8(Some(v)) => Some(v.as_str()), + ScalarValue::Utf8(None) => { + contains_null = true; + None + } + ScalarValue::LargeUtf8(Some(v)) => Some(v.as_str()), + ScalarValue::LargeUtf8(None) => { + contains_null = true; + None + } + datatype => unimplemented!("Unexpected type {} for InList", datatype), + }, + ColumnarValue::Array(_) => { + 
unimplemented!("InList does not yet support nested columns.") + } + }) + .collect::>(); + + Ok(ColumnarValue::Array(Arc::new( + array + .iter() + .map(|x| { + let contains = x.map(|x| values.contains(&x)); + match contains { + Some(true) => { + if negated { + Some(false) + } else { + Some(true) + } + } + Some(false) => { + if contains_null { + None + } else if negated { + Some(true) + } else { + Some(false) + } + } + None => None, + } + }) + .collect::(), + ))) + } +} + +impl fmt::Display for InListExpr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.negated { + write!(f, "{} NOT IN ({:?})", self.expr, self.list) + } else { + write!(f, "{} IN ({:?})", self.expr, self.list) + } + } +} + +impl PhysicalExpr for InListExpr { + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Boolean) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.expr.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let value = self.expr.evaluate(batch)?; + let value_data_type = value.data_type(); + let list_values = self + .list + .iter() + .map(|expr| expr.evaluate(batch)) + .collect::>>()?; + + let array = match value { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(scalar) => scalar.to_array(), + }; + + match value_data_type { + DataType::Float32 => { + make_contains!(array, list_values, self.negated, Float32, Float32Array) + } + DataType::Float64 => { + make_contains!(array, list_values, self.negated, Float64, Float64Array) + } + DataType::Int16 => { + make_contains!(array, list_values, self.negated, Int16, Int16Array) + } + DataType::Int32 => { + make_contains!(array, list_values, self.negated, Int32, Int32Array) + } + DataType::Int64 => { + make_contains!(array, list_values, self.negated, Int64, Int64Array) + } + DataType::Int8 => { + make_contains!(array, list_values, self.negated, Int8, Int8Array) + } + DataType::UInt16 => { + make_contains!(array, list_values, self.negated, 
UInt16, UInt16Array) + } + DataType::UInt32 => { + make_contains!(array, list_values, self.negated, UInt32, UInt32Array) + } + DataType::UInt64 => { + make_contains!(array, list_values, self.negated, UInt64, UInt64Array) + } + DataType::UInt8 => { + make_contains!(array, list_values, self.negated, UInt8, UInt8Array) + } + DataType::Boolean => { + make_contains!(array, list_values, self.negated, Boolean, BooleanArray) + } + DataType::Utf8 => self.compare_utf8::(array, list_values, self.negated), + DataType::LargeUtf8 => { + self.compare_utf8::(array, list_values, self.negated) + } + datatype => { + unimplemented!("InList does not support datatype {:?}.", datatype) + } + } + } +} + +/// Creates a unary expression InList +pub fn in_list( + expr: Arc, + list: Vec>, + negated: &bool, +) -> Result> { + Ok(Arc::new(InListExpr::new(expr, list, *negated))) +} + #[cfg(test)] mod tests { use super::*; @@ -3769,4 +4002,166 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; Ok(batch) } + + // applies the in_list expr to an input batch and list + macro_rules! 
in_list { + ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr) => {{ + let expr = in_list(col("a"), $LIST, $NEGATED).unwrap(); + let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows()); + let result = result + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + let expected = &BooleanArray::from($EXPECTED); + assert_eq!(expected, result); + }}; + } + + #[test] + fn in_list_utf8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let a = StringArray::from(vec![Some("a"), Some("d"), None]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + // expression: "a in ("a", "b")" + let list = vec![ + lit(ScalarValue::Utf8(Some("a".to_string()))), + lit(ScalarValue::Utf8(Some("b".to_string()))), + ]; + in_list!(batch, list, &false, vec![Some(true), Some(false), None]); + + // expression: "a not in ("a", "b")" + let list = vec![ + lit(ScalarValue::Utf8(Some("a".to_string()))), + lit(ScalarValue::Utf8(Some("b".to_string()))), + ]; + in_list!(batch, list, &true, vec![Some(false), Some(true), None]); + + // expression: "a in ("a", "b", NULL)" + let list = vec![ + lit(ScalarValue::Utf8(Some("a".to_string()))), + lit(ScalarValue::Utf8(Some("b".to_string()))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &false, vec![Some(true), None, None]); + + // expression: "a not in ("a", "b", NULL)" + let list = vec![ + lit(ScalarValue::Utf8(Some("a".to_string()))), + lit(ScalarValue::Utf8(Some("b".to_string()))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &true, vec![Some(false), None, None]); + + Ok(()) + } + + #[test] + fn in_list_int64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + let a = Int64Array::from(vec![Some(0), Some(2), None]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![ + lit(ScalarValue::Int64(Some(0))), + 
lit(ScalarValue::Int64(Some(1))), + ]; + in_list!(batch, list, &false, vec![Some(true), Some(false), None]); + + // expression: "a not in (0, 1)" + let list = vec![ + lit(ScalarValue::Int64(Some(0))), + lit(ScalarValue::Int64(Some(1))), + ]; + in_list!(batch, list, &true, vec![Some(false), Some(true), None]); + + // expression: "a in (0, 1, NULL)" + let list = vec![ + lit(ScalarValue::Int64(Some(0))), + lit(ScalarValue::Int64(Some(1))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &false, vec![Some(true), None, None]); + + // expression: "a not in (0, 1, NULL)" + let list = vec![ + lit(ScalarValue::Int64(Some(0))), + lit(ScalarValue::Int64(Some(1))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &true, vec![Some(false), None, None]); + + Ok(()) + } + + #[test] + fn in_list_float64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + // expression: "a in (0.0, 0.1)" + let list = vec![ + lit(ScalarValue::Float64(Some(0.0))), + lit(ScalarValue::Float64(Some(0.1))), + ]; + in_list!(batch, list, &false, vec![Some(true), Some(false), None]); + + // expression: "a not in (0.0, 0.1)" + let list = vec![ + lit(ScalarValue::Float64(Some(0.0))), + lit(ScalarValue::Float64(Some(0.1))), + ]; + in_list!(batch, list, &true, vec![Some(false), Some(true), None]); + + // expression: "a in (0.0, 0.1, NULL)" + let list = vec![ + lit(ScalarValue::Float64(Some(0.0))), + lit(ScalarValue::Float64(Some(0.1))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &false, vec![Some(true), None, None]); + + // expression: "a not in (0.0, 0.1, NULL)" + let list = vec![ + lit(ScalarValue::Float64(Some(0.0))), + lit(ScalarValue::Float64(Some(0.1))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &true, vec![Some(false), None, None]); + + Ok(()) + } + + #[test] + fn 
in_list_bool() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); + let a = BooleanArray::from(vec![Some(true), None]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + // expression: "a in (true)" + let list = vec![lit(ScalarValue::Boolean(Some(true)))]; + in_list!(batch, list, &false, vec![Some(true), None]); + + // expression: "a not in (true)" + let list = vec![lit(ScalarValue::Boolean(Some(true)))]; + in_list!(batch, list, &true, vec![Some(false), None]); + + // expression: "a in (true, NULL)" + let list = vec![ + lit(ScalarValue::Boolean(Some(true))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &false, vec![Some(true), None]); + + // expression: "a not in (true, NULL)" + let list = vec![ + lit(ScalarValue::Boolean(Some(true))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &true, vec![Some(false), None]); + + Ok(()) + } } diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index ad866809b1b..6af2c485b72 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -42,7 +42,10 @@ use crate::physical_plan::{expressions, Distribution}; use crate::physical_plan::{hash_utils, Partitioning}; use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, PhysicalPlanner}; use crate::prelude::JoinType; +use crate::scalar::ScalarValue; use crate::variable::VarType; +use arrow::compute::can_cast_types; + use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; use expressions::col; @@ -593,6 +596,57 @@ impl DefaultPhysicalPlanner { binary_expr } } + Expr::InList { + expr, + list, + negated, + } => match expr.as_ref() { + Expr::Literal(ScalarValue::Utf8(None)) => { + Ok(expressions::lit(ScalarValue::Boolean(None))) + } + _ => { + let value_expr = + self.create_physical_expr(expr, input_schema, ctx_state)?; + let value_expr_data_type = 
value_expr.data_type(input_schema)?; + + let list_exprs = + list.iter() + .map(|expr| match expr { + Expr::Literal(ScalarValue::Utf8(None)) => self + .create_physical_expr(expr, input_schema, ctx_state), + _ => { + let list_expr = self.create_physical_expr( + expr, + input_schema, + ctx_state, + )?; + let list_expr_data_type = + list_expr.data_type(input_schema)?; + + if list_expr_data_type == value_expr_data_type { + Ok(list_expr) + } else if can_cast_types( + &list_expr_data_type, + &value_expr_data_type, + ) { + expressions::cast( + list_expr, + input_schema, + value_expr.data_type(input_schema)?, + ) + } else { + Err(DataFusionError::Plan(format!( + "Unsupported CAST from {:?} to {:?}", + list_expr_data_type, value_expr_data_type + ))) + } + } + }) + .collect::>>()?; + + expressions::in_list(value_expr, list_exprs, negated) + } + }, other => Err(DataFusionError::NotImplemented(format!( "Physical plan does not support logical expression {:?}", other @@ -699,6 +753,7 @@ mod tests { use crate::logical_plan::{DFField, DFSchema, DFSchemaRef}; use crate::physical_plan::{csv::CsvReadOptions, expressions, Partitioning}; use crate::prelude::ExecutionConfig; + use crate::scalar::ScalarValue; use crate::{ logical_plan::{col, lit, sum, LogicalPlanBuilder}, physical_plan::SendableRecordBatchStream, @@ -884,6 +939,53 @@ mod tests { Ok(()) } + #[test] + fn in_list_types() -> Result<()> { + let testdata = arrow::util::test_util::arrow_test_data(); + let path = format!("{}/csv/aggregate_test_100.csv", testdata); + let options = CsvReadOptions::new().schema_infer_max_records(100); + + // expression: "a in ('a', 1)" + let list = vec![ + Expr::Literal(ScalarValue::Utf8(Some("a".to_string()))), + Expr::Literal(ScalarValue::Int64(Some(1))), + ]; + let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? + // filter clause needs the type coercion rule applied + .filter(col("c12").lt(lit(0.05)))? + .project(vec![col("c1").in_list(list, false)])? 
+ .build()?; + let execution_plan = plan(&logical_plan)?; + // verify that the plan correctly adds cast from Int64(1) to Utf8 + let expected = "InListExpr { expr: Column { name: \"c1\" }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8 }], negated: false }"; + assert!(format!("{:?}", execution_plan).contains(expected)); + + // expression: "a in (true, 'a')" + let list = vec![ + Expr::Literal(ScalarValue::Boolean(Some(true))), + Expr::Literal(ScalarValue::Utf8(Some("a".to_string()))), + ]; + let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? + // filter clause needs the type coercion rule applied + .filter(col("c12").lt(lit(0.05)))? + .project(vec![col("c12").lt_eq(lit(0.025)).in_list(list, false)])? + .build()?; + let execution_plan = plan(&logical_plan); + + let expected_error = "Unsupported CAST from Utf8 to Boolean"; + match execution_plan { + Ok(_) => panic!("Expected planning failure"), + Err(e) => assert!( + e.to_string().contains(expected_error), + "Error '{}' did not contain expected error '{}'", + e.to_string(), + expected_error + ), + } + + Ok(()) + } + /// An example extension node that doesn't do anything struct NoOpExtensionNode { schema: DFSchemaRef, diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index b607713c777..1879bd5cd21 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,7 +28,7 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - array, avg, col, concat, count, create_udf, length, lit, lower, ltrim, max, min, - rtrim, sum, trim, upper, JoinType, Partitioning, + array, avg, col, concat, count, create_udf, in_list, length, lit, lower, ltrim, max, + min, rtrim, sum, trim, upper, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; diff --git a/rust/datafusion/src/sql/planner.rs 
b/rust/datafusion/src/sql/planner.rs index dafff3dcf70..d2cbd7c4f61 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -759,6 +759,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { high: Box::new(self.sql_expr_to_logical_expr(&high)?), }), + SQLExpr::InList { + ref expr, + ref list, + ref negated, + } => { + let list_expr = list + .iter() + .map(|e| self.sql_expr_to_logical_expr(e)) + .collect::>>()?; + + Ok(Expr::InList { + expr: Box::new(self.sql_expr_to_logical_expr(&expr)?), + list: list_expr, + negated: *negated, + }) + } + SQLExpr::BinaryOp { ref left, ref op, diff --git a/rust/datafusion/src/sql/utils.rs b/rust/datafusion/src/sql/utils.rs index 0d12dddf410..ce8b4d1e01f 100644 --- a/rust/datafusion/src/sql/utils.rs +++ b/rust/datafusion/src/sql/utils.rs @@ -100,6 +100,20 @@ where matches.extend(find_exprs_in_expr(right.as_ref(), test_fn)); matches } + Expr::InList { + expr: nested_expr, + list, + .. + } => { + let mut matches = vec![]; + matches.extend(find_exprs_in_expr(nested_expr.as_ref(), test_fn)); + matches.extend( + list.iter() + .flat_map(|expr| find_exprs_in_expr(expr, test_fn)) + .collect::>(), + ); + matches + } Expr::Case { expr: case_expr_opt, when_then_expr, @@ -277,6 +291,18 @@ where low: Box::new(clone_with_replacement(&**low, replacement_fn)?), high: Box::new(clone_with_replacement(&**high, replacement_fn)?), }), + Expr::InList { + expr: nested_expr, + list, + negated, + } => Ok(Expr::InList { + expr: Box::new(clone_with_replacement(&**nested_expr, replacement_fn)?), + list: list + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, + negated: *negated, + }), Expr::BinaryExpr { left, right, op } => Ok(Expr::BinaryExpr { left: Box::new(clone_with_replacement(&**left, replacement_fn)?), op: op.clone(), diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index b810c428d54..2904b748a5d 100644 --- a/rust/datafusion/tests/sql.rs +++ 
b/rust/datafusion/tests/sql.rs @@ -1853,3 +1853,89 @@ async fn string_expressions() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +#[tokio::test] +async fn in_list_array() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "SELECT + c1 IN ('a', 'c') AS utf8_in_true + ,c1 IN ('x', 'y') AS utf8_in_false + ,c1 NOT IN ('x', 'y') AS utf8_not_in_true + ,c1 NOT IN ('a', 'c') AS utf8_not_in_false + ,CAST(CAST(c1 AS int) AS varchar) IN ('a', 'c') AS utf8_in_null + FROM aggregate_test_100 WHERE c12 < 0.05"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec!["true", "false", "true", "false", "NULL"], + vec!["true", "false", "true", "false", "NULL"], + vec!["true", "false", "true", "false", "NULL"], + vec!["false", "false", "true", "true", "NULL"], + vec!["false", "false", "true", "true", "NULL"], + vec!["false", "false", "true", "true", "NULL"], + vec!["false", "false", "true", "true", "NULL"], + ]; + assert_eq!(expected, actual); + Ok(()) +} + +#[tokio::test] +async fn in_list_scalar() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "SELECT + 'a' IN ('a','b') AS utf8_in_true + ,'c' IN ('a','b') AS utf8_in_false + ,'c' NOT IN ('a','b') AS utf8_not_in_true + ,'a' NOT IN ('a','b') AS utf8_not_in_false + ,NULL IN ('a','b') AS utf8_in_null + ,NULL NOT IN ('a','b') AS utf8_not_in_null + ,'a' IN ('a','b',NULL) AS utf8_in_null_true + ,'c' IN ('a','b',NULL) AS utf8_in_null_null + ,'a' NOT IN ('a','b',NULL) AS utf8_not_in_null_false + ,'c' NOT IN ('a','b',NULL) AS utf8_not_in_null_null + + ,0 IN (0,1,2) AS int64_in_true + ,3 IN (0,1,2) AS int64_in_false + ,3 NOT IN (0,1,2) AS int64_not_in_true + ,0 NOT IN (0,1,2) AS int64_not_in_false + ,NULL IN (0,1,2) AS int64_in_null + ,NULL NOT IN (0,1,2) AS int64_not_in_null + ,0 IN (0,1,2,NULL) AS int64_in_null_true + ,3 IN (0,1,2,NULL) AS int64_in_null_null + ,0 NOT IN (0,1,2,NULL) AS int64_not_in_null_false + ,3 NOT IN 
(0,1,2,NULL) AS int64_not_in_null_null + + ,0.0 IN (0.0,0.1,0.2) AS float64_in_true + ,0.3 IN (0.0,0.1,0.2) AS float64_in_false + ,0.3 NOT IN (0.0,0.1,0.2) AS float64_not_in_true + ,0.0 NOT IN (0.0,0.1,0.2) AS float64_not_in_false + ,NULL IN (0.0,0.1,0.2) AS float64_in_null + ,NULL NOT IN (0.0,0.1,0.2) AS float64_not_in_null + ,0.0 IN (0.0,0.1,0.2,NULL) AS float64_in_null_true + ,0.3 IN (0.0,0.1,0.2,NULL) AS float64_in_null_null + ,0.0 NOT IN (0.0,0.1,0.2,NULL) AS float64_not_in_null_false + ,0.3 NOT IN (0.0,0.1,0.2,NULL) AS float64_not_in_null_null + + ,'1' IN ('a','b',1) AS utf8_cast_in_true + ,'2' IN ('a','b',1) AS utf8_cast_in_false + ,'2' NOT IN ('a','b',1) AS utf8_cast_not_in_true + ,'1' NOT IN ('a','b',1) AS utf8_cast_not_in_false + ,NULL IN ('a','b',1) AS utf8_cast_in_null + ,NULL NOT IN ('a','b',1) AS utf8_cast_not_in_null + ,'1' IN ('a','b',NULL,1) AS utf8_cast_in_null_true + ,'2' IN ('a','b',NULL,1) AS utf8_cast_in_null_null + ,'1' NOT IN ('a','b',NULL,1) AS utf8_cast_not_in_null_false + ,'2' NOT IN ('a','b',NULL,1) AS utf8_cast_not_in_null_null + "; + let actual = execute(&mut ctx, sql).await; + + let expected = vec![vec![ + "true", "false", "true", "false", "NULL", "NULL", "true", "NULL", "false", + "NULL", "true", "false", "true", "false", "NULL", "NULL", "true", "NULL", + "false", "NULL", "true", "false", "true", "false", "NULL", "NULL", "true", + "NULL", "false", "NULL", "true", "false", "true", "false", "NULL", "NULL", + "true", "NULL", "false", "NULL", + ]]; + assert_eq!(expected, actual); + Ok(()) +}