diff --git a/rust/arrow/src/compute/kernels/comparison.rs b/rust/arrow/src/compute/kernels/comparison.rs index d73356e44fe..4268eaf568f 100644 --- a/rust/arrow/src/compute/kernels/comparison.rs +++ b/rust/arrow/src/compute/kernels/comparison.rs @@ -155,6 +155,34 @@ pub fn like_utf8(left: &StringArray, right: &StringArray) -> Result::from(Arc::new(data))) } +pub fn like_utf8_scalar(left: &StringArray, right: &str) -> Result { + let null_bit_buffer = left.data().null_buffer().cloned(); + let re_pattern = right.replace("%", ".*").replace("_", "."); + let re = Regex::new(&re_pattern).map_err(|e| { + ArrowError::ComputeError(format!( + "Unable to build regex from LIKE pattern: {}", + e + )) + })?; + + let mut result = BooleanBufferBuilder::new(left.len()); + for i in 0..left.len() { + let haystack = left.value(i); + result.append(re.is_match(haystack))?; + } + + let data = ArrayData::new( + DataType::Boolean, + left.len(), + None, + null_bit_buffer, + 0, + vec![result.finish()], + vec![], + ); + Ok(PrimitiveArray::::from(Arc::new(data))) +} + pub fn nlike_utf8(left: &StringArray, right: &StringArray) -> Result { let mut map = HashMap::new(); if left.len() != right.len() { @@ -200,6 +228,34 @@ pub fn nlike_utf8(left: &StringArray, right: &StringArray) -> Result::from(Arc::new(data))) } +pub fn nlike_utf8_scalar(left: &StringArray, right: &str) -> Result { + let null_bit_buffer = left.data().null_buffer().cloned(); + let re_pattern = right.replace("%", ".*").replace("_", "."); + let re = Regex::new(&re_pattern).map_err(|e| { + ArrowError::ComputeError(format!( + "Unable to build regex from LIKE pattern: {}", + e + )) + })?; + + let mut result = BooleanBufferBuilder::new(left.len()); + for i in 0..left.len() { + let haystack = left.value(i); + result.append(!re.is_match(haystack))?; + } + + let data = ArrayData::new( + DataType::Boolean, + left.len(), + None, + null_bit_buffer, + 0, + vec![result.finish()], + vec![], + ); + Ok(PrimitiveArray::::from(Arc::new(data))) +} + pub fn eq_utf8(left: &StringArray, right: &StringArray) -> Result { compare_op!(left, right, |a, b| a == b) } @@ -1081,6 +1137,13 @@ mod tests { like_utf8, vec![true, true, true, false] ); + test_utf8_scalar!( + test_utf8_array_like_scalar, + vec!["arrow", "parquet", "datafusion", "flight"], + "%ar%", + like_utf8_scalar, + vec![true, true, false, false] + ); test_utf8!( test_utf8_array_nlike, vec!["arrow", "arrow", "arrow", "arrow"], @@ -1088,6 +1151,13 @@ mod tests { nlike_utf8, vec![false, false, false, true] ); + test_utf8_scalar!( + test_utf8_array_nlike_scalar, + vec!["arrow", "parquet", "datafusion", "flight"], + "%ar%", + nlike_utf8_scalar, + vec![false, false, true, true] + ); test_utf8!( test_utf8_array_eq, diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml index 8af21d4582f..933710aed3c 100644 --- a/rust/datafusion/Cargo.toml +++ b/rust/datafusion/Cargo.toml @@ -78,3 +78,7 @@ harness = false [[bench]] name = "math_query_sql" harness = false + +[[bench]] +name = "filter_query_sql" +harness = false diff --git a/rust/datafusion/benches/filter_query_sql.rs b/rust/datafusion/benches/filter_query_sql.rs new file mode 100644 index 00000000000..7c7f6f887ce --- /dev/null +++ b/rust/datafusion/benches/filter_query_sql.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{Float32Array, Float64Array}, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, +}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion::prelude::ExecutionContext; +use datafusion::{datasource::MemTable, error::Result}; +use futures::executor::block_on; +use std::sync::Arc; + +async fn query(ctx: &mut ExecutionContext, sql: &str) { + // execute the query + let df = ctx.sql(&sql).unwrap(); + let results = df.collect().await.unwrap(); + + // display the relation + for _batch in results { + // println!("num_rows: {}", _batch.num_rows()); + } +} + +fn create_context(array_len: usize, batch_size: usize) -> Result { + // define a schema. + let schema = Arc::new(Schema::new(vec![ + Field::new("f32", DataType::Float32, false), + Field::new("f64", DataType::Float64, false), + ])); + + // define data. + let batches = (0..array_len / batch_size) + .map(|i| { + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![i as f32; batch_size])), + Arc::new(Float64Array::from(vec![i as f64; batch_size])), + ], + ) + .unwrap() + }) + .collect::>(); + + let mut ctx = ExecutionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + let provider = MemTable::new(schema, vec![batches])?; + ctx.register_table("t", Box::new(provider)); + + Ok(ctx) +} + +fn criterion_benchmark(c: &mut Criterion) { + let array_len = 524_288; // 2^19 + let batch_size = 4096; // 2^12 + + c.bench_function("filter_array", |b| { + let mut ctx = create_context(array_len, batch_size).unwrap(); + b.iter(|| block_on(query(&mut ctx, "select f32, f64 from t where f32 >= f64"))) + }); + + c.bench_function("filter_scalar", |b| { + let mut ctx = create_context(array_len, batch_size).unwrap(); + b.iter(|| { + block_on(query( + &mut ctx, + "select f32, f64 from t where f32 >= 250 and f64 > 250", + )) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs index ce3d038f726..ce8cc36afcb 100644 --- a/rust/datafusion/src/physical_plan/expressions.rs +++ b/rust/datafusion/src/physical_plan/expressions.rs @@ -21,31 +21,36 @@ use std::convert::TryFrom; use std::fmt; use std::sync::Arc; +use super::ColumnarValue; use crate::error::{DataFusionError, Result}; use crate::logical_plan::Operator; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; use crate::scalar::ScalarValue; -use arrow::array::{ - Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, - Int8Builder, LargeStringArray, StringBuilder, UInt16Builder, UInt32Builder, - UInt64Builder, UInt8Builder, -}; +use arrow::array::LargeStringArray; use arrow::compute; use arrow::compute::kernels; use arrow::compute::kernels::arithmetic::{add, divide, multiply, subtract}; use arrow::compute::kernels::boolean::{and, or}; use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow::compute::kernels::comparison::{ - eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, lt_eq_utf8, lt_utf8, neq_utf8, nlike_utf8, + eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar, +}; +use arrow::compute::kernels::comparison::{ + eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, like_utf8_scalar, lt_eq_utf8, lt_utf8, + neq_utf8, nlike_utf8, nlike_utf8_scalar, +}; +use arrow::compute::kernels::comparison::{ + eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, lt_eq_utf8_scalar, lt_utf8_scalar, + neq_utf8_scalar, }; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; -use arrow::datatypes::{DataType, Schema, TimeUnit}; +use arrow::datatypes::{DataType, DateUnit, Schema, TimeUnit}; use arrow::record_batch::RecordBatch; use arrow::{ array::{ - ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, StringArray, TimestampNanosecondArray, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, + ArrayRef, BooleanArray, Date32Array, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, StringArray, TimestampNanosecondArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, }, datatypes::Field, }; @@ -92,8 +97,10 @@ impl PhysicalExpr for Column { } /// Evaluate the expression - fn evaluate(&self, batch: &RecordBatch) -> Result { - Ok(batch.column(batch.schema().index_of(&self.name)?).clone()) + fn evaluate(&self, batch: &RecordBatch) -> Result { + Ok(ColumnarValue::Array( + batch.column(batch.schema().index_of(&self.name)?).clone(), + )) } } @@ -964,6 +971,44 @@ macro_rules! compute_utf8_op { }}; } +/// Invoke a compute kernel on a data array and a scalar value +macro_rules! compute_utf8_op_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + if let ScalarValue::Utf8(Some(string_value)) = $RIGHT { + Ok(Arc::new(paste::expr! {[<$OP _utf8_scalar>]}( + &ll, + &string_value, + )?)) + } else { + Err(DataFusionError::Internal(format!( + "compute_utf8_op_scalar failed to cast literal value {}", + $RIGHT + ))) + } + }}; +} + +/// Invoke a compute kernel on a data array and a scalar value +macro_rules! compute_op_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + use std::convert::TryInto; + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + // generate the scalar function name, such as lt_scalar, from the $OP parameter + // (which could have a value of lt) and the suffix _scalar + Ok(Arc::new(paste::expr! {[<$OP _scalar>]}( + &ll, + $RIGHT.try_into()?, + )?)) + }}; +} + /// Invoke a compute kernel on a pair of arrays macro_rules! compute_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -979,6 +1024,19 @@ macro_rules! compute_op { }}; } +macro_rules! binary_string_array_op_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + let result = match $LEFT.data_type() { + DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?}", + other + ))), + }; + Some(result) + }}; +} + macro_rules! binary_string_array_op { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ match $LEFT.data_type() { @@ -1015,6 +1073,37 @@ macro_rules! binary_primitive_array_op { }}; } +/// The binary_array_op_scalar macro includes types that extend beyond the primitive, +/// such as Utf8 strings. +macro_rules! binary_array_op_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + let result = match $LEFT.data_type() { + DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array), + DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array), + DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array), + DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array), + DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array), + DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array), + DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array), + DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array), + DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array), + DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array), + DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray), + DataType::Timestamp(TimeUnit::Nanosecond, None) => { + compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampNanosecondArray) + } + DataType::Date32(DateUnit::Day) => { + compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?}", + other + ))), + }; + Some(result) + }}; +} + /// The binary_array_op macro includes types that extend beyond the primitive, /// such as Utf8 strings. macro_rules! binary_array_op { @@ -1034,6 +1123,9 @@ macro_rules! binary_array_op { DataType::Timestamp(TimeUnit::Nanosecond, None) => { compute_op!($LEFT, $RIGHT, $OP, TimestampNanosecondArray) } + DataType::Date32(DateUnit::Day) => { + compute_op!($LEFT, $RIGHT, $OP, Date32Array) + } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?}", other @@ -1316,19 +1408,75 @@ impl PhysicalExpr for BinaryExpr { Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?) } - fn evaluate(&self, batch: &RecordBatch) -> Result { - let left = self.left.evaluate(batch)?; - let right = self.right.evaluate(batch)?; - if left.data_type() != right.data_type() { - // this should have been captured during planning + fn evaluate(&self, batch: &RecordBatch) -> Result { + let left_value = self.left.evaluate(batch)?; + let right_value = self.right.evaluate(batch)?; + let left_data_type = left_value.data_type(); + let right_data_type = right_value.data_type(); + + if left_data_type != right_data_type { return Err(DataFusionError::Internal(format!( "Cannot evaluate binary expression {:?} with types {:?} and {:?}", - self.op, - left.data_type(), - right.data_type() + self.op, left_data_type, right_data_type ))); } - match &self.op { + + let scalar_result = match (&left_value, &right_value) { + (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => { + // if left is array and right is literal - use scalar operations + match &self.op { + Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), lt), + Operator::LtEq => { + binary_array_op_scalar!(array, scalar.clone(), lt_eq) + } + Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), gt), + Operator::GtEq => { + binary_array_op_scalar!(array, scalar.clone(), gt_eq) + } + Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq), + Operator::NotEq => { + binary_array_op_scalar!(array, scalar.clone(), neq) + } + Operator::Like => { + binary_string_array_op_scalar!(array, scalar.clone(), like) + } + Operator::NotLike => { + binary_string_array_op_scalar!(array, scalar.clone(), nlike) + } + // if scalar operation is not supported - fallback to array implementation + _ => None, + } + } + (ColumnarValue::Scalar(scalar), ColumnarValue::Array(array)) => { + // if right is literal and left is array - reverse operator and parameters + match &self.op { + Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), gt), + Operator::LtEq => { + binary_array_op_scalar!(array, scalar.clone(), gt_eq) + } + Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), lt), + Operator::GtEq => { + binary_array_op_scalar!(array, scalar.clone(), lt_eq) + } + Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq), + Operator::NotEq => { + binary_array_op_scalar!(array, scalar.clone(), neq) + } + // if scalar operation is not supported - fallback to array implementation + _ => None, + } + } + (_, _) => None, + }; + + if let Some(result) = scalar_result { + return result.map(|a| ColumnarValue::Array(a)); + } + + // if both arrays or both literals - extract arrays and continue execution + let (left, right) = (left_value.into_array(batch), right_value.into_array(batch)); + + let result: Result = match &self.op { Operator::Like => binary_string_array_op!(left, right, like), Operator::NotLike => binary_string_array_op!(left, right, nlike), Operator::Lt => binary_array_op!(left, right, lt), @@ -1342,7 +1490,7 @@ impl PhysicalExpr for BinaryExpr { Operator::Multiply => binary_primitive_array_op!(left, right, multiply), Operator::Divide => binary_primitive_array_op!(left, right, divide), Operator::And => { - if left.data_type() == &DataType::Boolean { + if left_data_type == DataType::Boolean { boolean_op!(left, right, and) } else { return Err(DataFusionError::Internal(format!( @@ -1354,21 +1502,20 @@ impl PhysicalExpr for BinaryExpr { } } Operator::Or => { - if left.data_type() == &DataType::Boolean { + if left_data_type == DataType::Boolean { boolean_op!(left, right, or) } else { return Err(DataFusionError::Internal(format!( "Cannot evaluate binary expression {:?} with types {:?} and {:?}", - self.op, - left.data_type(), - right.data_type() + self.op, left_data_type, right_data_type ))); } } Operator::Modulus => Err(DataFusionError::NotImplemented( "Modulus operator is still not supported".to_string(), )), - } + }; + result.map(|a| ColumnarValue::Array(a)) } } @@ -1403,6 +1550,7 @@ impl fmt::Display for NotExpr { write!(f, "NOT {}", self.arg) } } + impl PhysicalExpr for NotExpr { fn data_type(&self, _input_schema: &Schema) -> Result { return Ok(DataType::Boolean); @@ -1412,13 +1560,27 @@ impl PhysicalExpr for NotExpr { self.arg.nullable(input_schema) } - fn evaluate(&self, batch: &RecordBatch) -> Result { + fn evaluate(&self, batch: &RecordBatch) -> Result { let arg = self.arg.evaluate(batch)?; - let arg = arg - .as_any() - .downcast_ref::() - .expect("boolean_op failed to downcast array"); - return Ok(Arc::new(arrow::compute::kernels::boolean::not(arg)?)); + match arg { + ColumnarValue::Array(array) => { + let array = array.as_any().downcast_ref::().ok_or( + DataFusionError::Internal( + "boolean_op failed to downcast array".to_owned(), + ), + )?; + Ok(ColumnarValue::Array(Arc::new( + arrow::compute::kernels::boolean::not(array)?, + ))) + } + ColumnarValue::Scalar(scalar) => { + use std::convert::TryInto; + let bool_value: bool = scalar.try_into()?; + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( + !bool_value, + )))) + } + } } } @@ -1472,9 +1634,16 @@ impl PhysicalExpr for IsNullExpr { Ok(false) } - fn evaluate(&self, batch: &RecordBatch) -> Result { + fn evaluate(&self, batch: &RecordBatch) -> Result { let arg = self.arg.evaluate(batch)?; - return Ok(Arc::new(arrow::compute::is_null(&arg)?)); + match arg { + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new( + arrow::compute::is_null(&array)?, + ))), + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( + ScalarValue::Boolean(Some(scalar.is_null())), + )), + } } } @@ -1510,9 +1679,16 @@ impl PhysicalExpr for IsNotNullExpr { Ok(false) } - fn evaluate(&self, batch: &RecordBatch) -> Result { + fn evaluate(&self, batch: &RecordBatch) -> Result { let arg = self.arg.evaluate(batch)?; - return Ok(Arc::new(arrow::compute::is_not_null(&arg)?)); + match arg { + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new( + arrow::compute::is_not_null(&array)?, + ))), + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( + ScalarValue::Boolean(Some(!scalar.is_null())), + )), + } } } @@ -1555,9 +1731,20 @@ impl PhysicalExpr for CastExpr { self.expr.nullable(input_schema) } - fn evaluate(&self, batch: &RecordBatch) -> Result { + fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?; - Ok(kernels::cast::cast(&value, &self.cast_type)?) + match value { + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(kernels::cast::cast( + &array, + &self.cast_type, + )?)), + ColumnarValue::Scalar(scalar) => { + let scalar_array = scalar.to_array(); + let cast_array = kernels::cast::cast(&scalar_array, &self.cast_type)?; + let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; + Ok(ColumnarValue::Scalar(cast_scalar)) + } + } } } @@ -1596,24 +1783,6 @@ impl Literal { } } -/// Build array containing the same literal value repeated. This is necessary because the Arrow -/// memory model does not have the concept of a scalar value currently. -macro_rules! build_literal_array { - ($BATCH:ident, $BUILDER:ident, $VALUE:expr) => {{ - let mut builder = $BUILDER::new($BATCH.num_rows()); - if $VALUE.is_none() { - for _ in 0..$BATCH.num_rows() { - builder.append_null()?; - } - } else { - for _ in 0..$BATCH.num_rows() { - builder.append_value($VALUE.unwrap())?; - } - } - Ok(Arc::new(builder.finish())) - }}; -} - impl fmt::Display for Literal { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.value) @@ -1629,46 +1798,8 @@ impl PhysicalExpr for Literal { Ok(self.value.is_null()) } - fn evaluate(&self, batch: &RecordBatch) -> Result { - match &self.value { - ScalarValue::Int8(value) => build_literal_array!(batch, Int8Builder, *value), - ScalarValue::Int16(value) => { - build_literal_array!(batch, Int16Builder, *value) - } - ScalarValue::Int32(value) => { - build_literal_array!(batch, Int32Builder, *value) - } - ScalarValue::Int64(value) => { - build_literal_array!(batch, Int64Builder, *value) - } - ScalarValue::UInt8(value) => { - build_literal_array!(batch, UInt8Builder, *value) - } - ScalarValue::UInt16(value) => { - build_literal_array!(batch, UInt16Builder, *value) - } - ScalarValue::UInt32(value) => { - build_literal_array!(batch, UInt32Builder, *value) - } - ScalarValue::UInt64(value) => { - build_literal_array!(batch, UInt64Builder, *value) - } - ScalarValue::Float32(value) => { - build_literal_array!(batch, Float32Builder, *value) - } - ScalarValue::Float64(value) => { - build_literal_array!(batch, Float64Builder, *value) - } - ScalarValue::Utf8(value) => build_literal_array!( - batch, - StringBuilder, - value.as_ref().and_then(|e| Some(&*e)) - ), - other => Err(DataFusionError::Internal(format!( - "Unsupported literal type {:?}", - other - ))), - } + fn evaluate(&self, _batch: &RecordBatch) -> Result { + Ok(ColumnarValue::Scalar(self.value.clone())) } } @@ -1689,8 +1820,18 @@ pub struct PhysicalSortExpr { impl PhysicalSortExpr { /// evaluate the sort expression into SortColumn that can be passed into arrow sort kernel pub fn evaluate_to_sort_column(&self, batch: &RecordBatch) -> Result { + let value_to_sort = self.expr.evaluate(batch)?; + let array_to_sort = match value_to_sort { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(scalar) => { + return Err(DataFusionError::Internal(format!( + "Sort operation is not applicable to scalar value {}", + scalar + ))); + } + }; Ok(SortColumn { - values: self.expr.evaluate(batch)?, + values: array_to_sort, options: Some(self.options), }) } @@ -1734,7 +1875,7 @@ mod tests { // expression: "a < b" let lt = binary_simple(col("a"), Operator::Lt, col("b")); - let result = lt.evaluate(&batch)?; + let result = lt.evaluate(&batch)?.into_array(&batch); assert_eq!(result.len(), 5); let expected = vec![false, false, true, true, true]; @@ -1770,7 +1911,7 @@ mod tests { ); assert_eq!("a < b OR a = b", format!("{}", expr)); - let result = expr.evaluate(&batch)?; + let result = expr.evaluate(&batch)?.into_array(&batch); assert_eq!(result.len(), 5); let expected = vec![true, true, false, true, false]; @@ -1796,7 +1937,7 @@ mod tests { let literal_expr = lit(ScalarValue::from(42i32)); assert_eq!("42", format!("{}", literal_expr)); - let literal_array = literal_expr.evaluate(&batch)?; + let literal_array = literal_expr.evaluate(&batch)?.into_array(&batch); let literal_array = literal_array.as_any().downcast_ref::().unwrap(); // note that the contents of the literal array are unrelated to the batch contents except for the length of the array @@ -1835,7 +1976,7 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $C_TYPE); // compute - let result = expression.evaluate(&batch)?; + let result = expression.evaluate(&batch)?.into_array(&batch); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $C_TYPE); @@ -1955,7 +2096,7 @@ mod tests { // build dictionary let keys_builder = PrimitiveBuilder::::new(10); - let values_builder = StringBuilder::new(10); + let values_builder = arrow::array::StringBuilder::new(10); let mut dict_builder = StringDictionaryBuilder::new(keys_builder, values_builder); dict_builder.append("one")?; @@ -1986,7 +2127,7 @@ mod tests { assert_eq!(expression.data_type(&schema)?, DataType::Boolean); // evaluate and verify the result type matched - let result = expression.evaluate(&batch)?; + let result = expression.evaluate(&batch)?.into_array(&batch); assert_eq!(result.data_type(), &DataType::Boolean); // verify that the result itself is correct @@ -2000,7 +2141,7 @@ mod tests { assert_eq!(expression.data_type(&schema)?, DataType::Boolean); // evaluate and verify the result type matched - let result = expression.evaluate(&batch)?; + let result = expression.evaluate(&batch)?.into_array(&batch); assert_eq!(result.data_type(), &DataType::Boolean); // verify that the result itself is correct @@ -2056,7 +2197,7 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?; + let result = expression.evaluate(&batch)?.into_array(&batch); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -2613,6 +2754,7 @@ mod tests { let values = expr .iter() .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch))) .collect::>>()?; accum.update_batch(&values)?; accum.evaluate() @@ -2710,7 +2852,7 @@ mod tests { ) -> Result<()> { let arithmetic_op = binary_simple(col("a"), op, col("b")); let batch = RecordBatch::try_new(schema, data)?; - let result = arithmetic_op.evaluate(&batch)?; + let result = arithmetic_op.evaluate(&batch)?.into_array(&batch); assert_array_eq::(expected, result); @@ -2745,7 +2887,7 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; - let result = expr.evaluate(&batch)?; + let result = expr.evaluate(&batch)?.into_array(&batch); let result = result .as_any() .downcast_ref::() @@ -2774,7 +2916,7 @@ mod tests { // expression: "a is null" let expr = is_null(col("a")).unwrap(); - let result = expr.evaluate(&batch)?; + let result = expr.evaluate(&batch)?.into_array(&batch); let result = result .as_any() .downcast_ref::() @@ -2795,7 +2937,7 @@ mod tests { // expression: "a is not null" let expr = is_not_null(col("a")).unwrap(); - let result = expr.evaluate(&batch)?; + let result = expr.evaluate(&batch)?.into_array(&batch); let result = result .as_any() .downcast_ref::() diff --git a/rust/datafusion/src/physical_plan/filter.rs b/rust/datafusion/src/physical_plan/filter.rs index 4a61d7d9dac..7c3da888c04 100644 --- a/rust/datafusion/src/physical_plan/filter.rs +++ b/rust/datafusion/src/physical_plan/filter.rs @@ -27,7 +27,7 @@ use super::{RecordBatchStream, SendableRecordBatchStream}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr}; use arrow::array::BooleanArray; -use arrow::compute::filter; +use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; @@ -128,6 +128,7 @@ fn batch_filter( ) -> ArrowResult { predicate .evaluate(&batch) + .map(|v| v.into_array(batch)) .map_err(DataFusionError::into_arrow_external_error) .and_then(|array| { array @@ -139,17 +140,9 @@ fn batch_filter( ) .into_arrow_external_error(), ) - // apply predicate to each column - .and_then(|predicate| { - batch - .columns() - .iter() - .map(|column| filter(column.as_ref(), predicate)) - .collect::>>() - }) + // apply filter array to record batch + .and_then(|filter_array| filter_record_batch(batch, filter_array)) }) - // build RecordBatch - .and_then(|columns| RecordBatch::try_new(batch.schema().clone(), columns)) } impl Stream for FilterExecStream { diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index d0a121139d8..12402ec90b0 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -31,7 +31,7 @@ use super::{ type_coercion::{coerce, data_types}, - PhysicalExpr, + ColumnarValue, PhysicalExpr, }; use crate::error::{DataFusionError, Result}; use crate::physical_plan::array_expressions; @@ -343,17 +343,17 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(true) } - fn evaluate(&self, batch: &RecordBatch) -> Result { + fn evaluate(&self, batch: &RecordBatch) -> Result { // evaluate the arguments let inputs = self .args .iter() - .map(|e| e.evaluate(batch)) + .map(|e| e.evaluate(batch).map(|v| v.into_array(batch))) .collect::>>()?; // evaluate the function let fun = self.fun.as_ref(); - (fun)(&inputs) + (fun)(&inputs).map(|a| ColumnarValue::Array(a)) } } @@ -381,8 +381,8 @@ mod tests { assert_eq!(expr.data_type(&schema)?, DataType::Float64); // evaluate works - let result = - expr.evaluate(&RecordBatch::try_new(Arc::new(schema.clone()), columns)?)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + let result = expr.evaluate(&batch)?.into_array(&batch); // downcast works let result = result.as_any().downcast_ref::().unwrap(); @@ -422,8 +422,8 @@ mod tests { assert_eq!(expr.data_type(&schema)?, DataType::Utf8); // evaluate works - let result = - expr.evaluate(&RecordBatch::try_new(Arc::new(schema.clone()), columns)?)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + let result = expr.evaluate(&batch)?.into_array(&batch); // downcast works let result = result.as_any().downcast_ref::().unwrap(); @@ -475,8 +475,8 @@ mod tests { ); // evaluate works - let result = - expr.evaluate(&RecordBatch::try_new(Arc::new(schema.clone()), columns)?)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + let result = expr.evaluate(&batch)?.into_array(&batch); // downcast works let result = result diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index 97a3d5ca6da..21aeefcea35 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -433,6 +433,7 @@ fn evaluate( ) -> Result> { expr.iter() .map(|expr| expr.evaluate(&batch)) + .map(|r| r.map(|v| v.into_array(batch))) .collect::>>() } @@ -561,6 +562,7 @@ fn aggregate_batch( let values = &expr .iter() .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch))) .collect::>>()?; // 1.3 diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs index a2bc3bef7f4..0a9711ac8aa 100644 --- a/rust/datafusion/src/physical_plan/mod.rs +++ b/rust/datafusion/src/physical_plan/mod.rs @@ -110,6 +110,30 @@ pub enum Distribution { SinglePartition, } +/// Represents the result from an expression +pub enum ColumnarValue { + /// Array of values + Array(ArrayRef), + /// A single value + Scalar(ScalarValue), +} + +impl ColumnarValue { + fn data_type(&self) -> DataType { + match self { + ColumnarValue::Array(array_value) => array_value.data_type().clone(), + ColumnarValue::Scalar(scalar_value) => scalar_value.get_datatype(), + } + } + + fn into_array(self, batch: &RecordBatch) -> ArrayRef { + match self { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows()), + } + } +} + /// Expression that can be evaluated against a RecordBatch /// A Physical expression knows its type, nullability and how to evaluate itself. pub trait PhysicalExpr: Send + Sync + Display + Debug { @@ -118,7 +142,7 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug { /// Determine whether this expression is nullable, given the schema of the input fn nullable(&self, input_schema: &Schema) -> Result; /// Evaluate an expression against a RecordBatch - fn evaluate(&self, batch: &RecordBatch) -> Result; + fn evaluate(&self, batch: &RecordBatch) -> Result; } /// An aggregate expression that: diff --git a/rust/datafusion/src/physical_plan/projection.rs b/rust/datafusion/src/physical_plan/projection.rs index 65b2828ca9a..a3329041aea 100644 --- a/rust/datafusion/src/physical_plan/projection.rs +++ b/rust/datafusion/src/physical_plan/projection.rs @@ -130,6 +130,7 @@ fn batch_project( expressions .iter() .map(|expr| expr.evaluate(&batch)) + .map(|r| r.map(|v| v.into_array(batch))) .collect::>>() .map_or_else( |e| Err(DataFusionError::into_arrow_external_error(e)), diff --git a/rust/datafusion/src/scalar.rs b/rust/datafusion/src/scalar.rs index 2eb1d69a617..06309ab84c0 100644 --- a/rust/datafusion/src/scalar.rs +++ b/rust/datafusion/src/scalar.rs @@ -19,11 +19,6 @@ use std::{convert::TryFrom, fmt, sync::Arc}; -use arrow::array::{ - Array, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, LargeStringArray, ListArray, StringArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, -}; use arrow::array::{ Int16Builder, Int32Builder, Int64Builder, Int8Builder, ListBuilder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, @@ -32,6 +27,14 @@ use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; +use arrow::{ + array::{ + Array, BooleanArray, Date32Array, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, LargeStringArray, ListArray, StringArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, + }, + datatypes::DateUnit, +}; use crate::error::{DataFusionError, Result}; @@ -67,6 +70,8 @@ pub enum ScalarValue { LargeUtf8(Option), /// list of nested ScalarValue List(Option>, DataType), + /// Date stored as a signed 32bit int + Date32(Option), } macro_rules! typed_cast { @@ -80,29 +85,33 @@ macro_rules! typed_cast { } macro_rules! build_list { - ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr) => {{ + ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ match $VALUES { None => { let mut builder = ListBuilder::new($VALUE_BUILDER_TY::new(0)); - builder.append(false).unwrap(); + for _ in 0..$SIZE { + builder.append(false).unwrap(); + } builder.finish() } Some(values) => { let mut builder = ListBuilder::new($VALUE_BUILDER_TY::new(values.len())); - for scalar_value in values { - match scalar_value { - ScalarValue::$SCALAR_TY(Some(v)) => { - builder.values().append_value(*v).unwrap() - } - ScalarValue::$SCALAR_TY(None) => { - builder.values().append_null().unwrap(); - } - _ => panic!("Incompatible ScalarValue for list"), - }; + for _ in 0..$SIZE { + for scalar_value in values { + match scalar_value { + ScalarValue::$SCALAR_TY(Some(v)) => { + builder.values().append_value(*v).unwrap() + } + ScalarValue::$SCALAR_TY(None) => { + builder.values().append_null().unwrap(); + } + _ => panic!("Incompatible ScalarValue for list"), + }; + } + builder.append(true).unwrap(); } - builder.append(true).unwrap(); builder.finish() } } @@ -129,6 +138,7 @@ impl ScalarValue { ScalarValue::List(_, data_type) => { DataType::List(Box::new(Field::new("item", data_type.clone(), true))) } + ScalarValue::Date32(_) => DataType::Date32(DateUnit::Day), } } @@ -155,33 +165,43 @@ impl ScalarValue { /// Converts a scalar value into an 1-row array. pub fn to_array(&self) -> ArrayRef { + self.to_array_of_size(1) + } + + /// Converts a scalar value into an array of `size` rows. + pub fn to_array_of_size(&self, size: usize) -> ArrayRef { match self { - ScalarValue::Boolean(e) => Arc::new(BooleanArray::from(vec![*e])) as ArrayRef, - ScalarValue::Float64(e) => Arc::new(Float64Array::from(vec![*e])) as ArrayRef, - ScalarValue::Float32(e) => Arc::new(Float32Array::from(vec![*e])), - ScalarValue::Int8(e) => Arc::new(Int8Array::from(vec![*e])), - ScalarValue::Int16(e) => Arc::new(Int16Array::from(vec![*e])), - ScalarValue::Int32(e) => Arc::new(Int32Array::from(vec![*e])), - ScalarValue::Int64(e) => Arc::new(Int64Array::from(vec![*e])), - ScalarValue::UInt8(e) => Arc::new(UInt8Array::from(vec![*e])), - ScalarValue::UInt16(e) => Arc::new(UInt16Array::from(vec![*e])), - ScalarValue::UInt32(e) => Arc::new(UInt32Array::from(vec![*e])), - ScalarValue::UInt64(e) => Arc::new(UInt64Array::from(vec![*e])), - ScalarValue::Utf8(e) => Arc::new(StringArray::from(vec![e.as_deref()])), + ScalarValue::Boolean(e) => { + Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef + } + ScalarValue::Float64(e) => { + Arc::new(Float64Array::from(vec![*e; size])) as ArrayRef + } + ScalarValue::Float32(e) => Arc::new(Float32Array::from(vec![*e; size])), + ScalarValue::Int8(e) => Arc::new(Int8Array::from(vec![*e; size])), + ScalarValue::Int16(e) => Arc::new(Int16Array::from(vec![*e; size])), + ScalarValue::Int32(e) => Arc::new(Int32Array::from(vec![*e; size])), + ScalarValue::Int64(e) => Arc::new(Int64Array::from(vec![*e; size])), + ScalarValue::UInt8(e) => Arc::new(UInt8Array::from(vec![*e; size])), + ScalarValue::UInt16(e) => Arc::new(UInt16Array::from(vec![*e; size])), + ScalarValue::UInt32(e) => Arc::new(UInt32Array::from(vec![*e; size])), + ScalarValue::UInt64(e) => Arc::new(UInt64Array::from(vec![*e; size])), + ScalarValue::Utf8(e) => Arc::new(StringArray::from(vec![e.as_deref(); size])), ScalarValue::LargeUtf8(e) => { - Arc::new(LargeStringArray::from(vec![e.as_deref()])) + Arc::new(LargeStringArray::from(vec![e.as_deref(); size])) } ScalarValue::List(values, data_type) => Arc::new(match data_type { - DataType::Int8 => build_list!(Int8Builder, Int8, values), - DataType::Int16 => build_list!(Int16Builder, Int16, values), - DataType::Int32 => build_list!(Int32Builder, Int32, values), - DataType::Int64 => build_list!(Int64Builder, Int64, values), - DataType::UInt8 => build_list!(UInt8Builder, UInt8, values), - DataType::UInt16 => build_list!(UInt16Builder, UInt16, values), - DataType::UInt32 => build_list!(UInt32Builder, UInt32, values), - DataType::UInt64 => build_list!(UInt64Builder, UInt64, values), + DataType::Int8 => build_list!(Int8Builder, Int8, values, size), + DataType::Int16 => build_list!(Int16Builder, Int16, values, size), + DataType::Int32 => build_list!(Int32Builder, Int32, values, size), + DataType::Int64 => build_list!(Int64Builder, Int64, values, size), + DataType::UInt8 => build_list!(UInt8Builder, UInt8, values, size), + DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size), + DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size), + DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size), _ => panic!("Unexpected DataType for list"), }), + ScalarValue::Date32(e) => Arc::new(Date32Array::from(vec![*e; size])), } } @@ -217,6 +237,9 @@ impl ScalarValue { }; ScalarValue::List(value, nested_type.data_type().clone()) } + DataType::Date32(DateUnit::Day) => { + typed_cast!(array, index, Date32Array, Date32) + } other => { return Err(DataFusionError::NotImplemented(format!( "Can't create a scalar of array of type \"{:?}\"", @@ -293,6 +316,54 @@ impl From for ScalarValue { } } +macro_rules! impl_try_from { + ($SCALAR:ident, $NATIVE:ident) => { + impl TryFrom for $NATIVE { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::$SCALAR(Some(inner_value)) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } + } + }; +} + +impl_try_from!(Int8, i8); +impl_try_from!(Int16, i16); + +// special implementation for i32 because of Date32 +impl TryFrom for i32 { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::Int32(Some(inner_value)) + | ScalarValue::Date32(Some(inner_value)) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } +} + +impl_try_from!(Int64, i64); +impl_try_from!(UInt8, u8); +impl_try_from!(UInt16, u16); +impl_try_from!(UInt32, u32); +impl_try_from!(UInt64, u64); +impl_try_from!(Float32, f32); +impl_try_from!(Float64, f64); +impl_try_from!(Boolean, bool); + impl TryFrom<&DataType> for ScalarValue { type Error = DataFusionError; @@ -360,6 +431,7 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, + ScalarValue::Date32(e) => format_option!(f, e)?, }; Ok(()) } @@ -382,6 +454,7 @@ impl fmt::Debug for ScalarValue { ScalarValue::Utf8(_) => write!(f, "Utf8(\"{}\")", self), ScalarValue::LargeUtf8(_) => write!(f, "LargeUtf8(\"{}\")", self), ScalarValue::List(_, _) => write!(f, "List([{}])", self), + ScalarValue::Date32(_) => write!(f, "Date32(\"{}\")", self), } } } diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index ca8a6cc4f0c..5435bc0a885 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1418,3 +1418,28 @@ async fn query_without_from() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn query_scalar_minus_array() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + None, + Some(3), + ]))], + )?; + + let table = MemTable::new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Box::new(table)); + let sql = "SELECT 4 - c1 FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["4"], vec!["3"], vec!["NULL"], vec!["1"]]; + assert_eq!(expected, actual); + Ok(()) +}