Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 193 additions & 4 deletions datafusion/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@

//! Collection of utility functions that are leveraged by the query optimizer rules

use arrow::array::new_null_array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

use super::optimizer::OptimizerRule;
use crate::execution::context::ExecutionProps;
use crate::execution::context::{ExecutionContextState, ExecutionProps};
use crate::logical_plan::{
build_join_schema, Column, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder,
Operator, Partitioning, Recursion,
build_join_schema, Column, DFSchema, DFSchemaRef, Expr, ExprRewriter, LogicalPlan,
LogicalPlanBuilder, Operator, Partitioning, Recursion,
};
use crate::physical_plan::functions::Volatility;
use crate::physical_plan::planner::DefaultPhysicalPlanner;
use crate::prelude::lit;
use crate::scalar::ScalarValue;
use crate::{
Expand Down Expand Up @@ -468,10 +474,144 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
}
}

/// Replaces every constant sub-expression of `expr` with the literal it
/// evaluates to, leaving the rest of the expression tree as it was.
///
/// For example, `'foo' != 'foo' OR col1 = 'baz'` is rewritten to
/// `false OR col1 = 'baz'`: the left operand only involves literals and is
/// folded, while the right operand references a column and is kept.
///
/// # Errors
///
/// Returns any error raised while rewriting or evaluating a sub-expression.
pub fn partially_evaluate_expr(expr: Expr) -> Result<Expr> {
    expr.rewrite(&mut ExprEvaluator::new())
}

/// Rewriter state used by `partially_evaluate_expr` to replace constant
/// sub-expressions with the literal they evaluate to.
struct ExprEvaluator {
/// Stack of flags, one per expression currently being visited:
/// `pre_visit` pushes `true` for a node and `mutate` pops it again.
/// When `mutate` meets a node that cannot be evaluated (for example a
/// column reference) it clears the flags still on the stack, i.e. the
/// node's own flag and those of its ancestors. A flag that is still
/// `true` when popped therefore means nothing non-constant was found
/// in that node or anywhere below it.
can_evaluate: Vec<bool>,

/// Execution state handed to the planner when creating physical expressions
ctx_state: ExecutionContextState,
/// Planner used to turn a logical `Expr` into a physical expression
planner: DefaultPhysicalPlanner,
/// Logical schema handed to the planner; `new` creates it empty
input_schema: DFSchema,
/// Dummy batch the physical expressions are evaluated against
input_batch: RecordBatch,
}

/// Constant folding driven by the `ExprRewriter` callbacks: `pre_visit`
/// pushes a flag for a node, `mutate` pops it once the node's children have
/// been rewritten and replaces the node with a literal if the flag survived.
impl ExprRewriter for ExprEvaluator {
    /// Pushes the flag for `_expr`, optimistically assuming it can be
    /// evaluated until `mutate` finds a reason why it cannot.
    fn pre_visit(
        &mut self,
        _expr: &Expr,
    ) -> Result<crate::logical_plan::RewriteRecursion> {
        self.can_evaluate.push(true);

        Ok(crate::logical_plan::RewriteRecursion::Continue)
    }

    /// Replaces `expr` with `Expr::Literal` when neither it nor any of its
    /// descendants was found to be non-constant; otherwise returns `expr`
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Returns the error produced while evaluating `expr`, if any.
    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
        // check for reasons we can't evaluate this node
        let self_ok_to_evaluate = match &expr {
            // None of these is a plain scalar value that can be computed in
            // isolation: a column needs input rows, an alias carries a name
            // that folding would drop, and aggregate / window / sort /
            // wildcard / variable expressions are not standalone scalar
            // expressions.
            Expr::Column(_)
            | Expr::Alias(..)
            | Expr::ScalarVariable(..)
            | Expr::AggregateFunction { .. }
            | Expr::AggregateUDF { .. }
            | Expr::WindowFunction { .. }
            | Expr::Sort { .. }
            | Expr::Wildcard => false,
            Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()),
            Expr::ScalarUDF { fun, .. } => Self::volatility_ok(fun.signature.volatility),
            _ => true,
        };

        // A node that can't be evaluated poisons itself (top of the stack)
        // and every ancestor below it on the stack.
        if !self_ok_to_evaluate {
            for flag in self.can_evaluate.iter_mut().rev() {
                if !*flag {
                    // optimization: once an already cleared flag is found,
                    // all flags further down the stack are cleared as well
                    break;
                }
                *flag = false;
            }
        }

        // pre_visit pushed the flag for this node, so it can be popped here
        let ok_to_evaluate = self
            .can_evaluate
            .pop()
            .expect("pre_visit pushes a flag for every node handed to mutate");

        if ok_to_evaluate {
            self.evaluate_to_scalar(expr).map(Expr::Literal)
        } else {
            Ok(expr)
        }
    }
}

impl ExprEvaluator {
    /// Creates an evaluator with an empty flag stack, a default physical
    /// planner, a fresh execution state, an empty logical schema and a
    /// one-row dummy input batch.
    pub fn new() -> Self {
        // The dummy column name shouldn't really matter as only scalar
        // expressions will be evaluated
        const DUMMY_COL_NAME: &str = ".";

        let schema =
            Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Float64, true)]);
        // a single null row, so evaluating an expression yields exactly one value
        let col = new_null_array(&DataType::Float64, 1);
        let input_batch = RecordBatch::try_new(std::sync::Arc::new(schema), vec![col])
            .expect("one nullable Float64 column matches the one-field dummy schema");

        Self {
            can_evaluate: vec![],
            ctx_state: ExecutionContextState::new(),
            planner: DefaultPhysicalPlanner::default(),
            input_schema: DFSchema::empty(),
            input_batch,
        }
    }

    /// Can a function of the specified volatility be evaluated?
    fn volatility_ok(volatility: Volatility) -> bool {
        match volatility {
            Volatility::Immutable | Volatility::Stable => true,
            Volatility::Volatile => false,
        }
    }

    /// Evaluates `expr` against the dummy input batch and returns the
    /// resulting value.
    ///
    /// A literal is returned as is. Anything else is planned into a physical
    /// expression and evaluated.
    ///
    /// # Errors
    ///
    /// Returns an error if planning or evaluation fails, or if evaluation
    /// produces an array whose length is not 1.
    fn evaluate_to_scalar(&self, expr: Expr) -> Result<ScalarValue> {
        if let Expr::Literal(s) = expr {
            // `s` is already owned here, no clone needed
            return Ok(s);
        }

        let phys_expr = self.planner.create_physical_expr(
            &expr,
            &self.input_schema,
            &self.input_batch.schema(),
            &self.ctx_state,
        )?;

        match phys_expr.evaluate(&self.input_batch)? {
            crate::physical_plan::ColumnarValue::Array(a) => {
                if a.len() == 1 {
                    Ok(ScalarValue::try_from_array(&a, 0)?)
                } else {
                    Err(DataFusionError::Execution(format!(
                        "Could not evaluate the expression, found a result of length {}",
                        a.len()
                    )))
                }
            }
            crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s),
        }
    }
}

#[cfg(test)]
mod tests {
use super::*;
use crate::logical_plan::col;
use crate::{
logical_plan::{col, lit_timestamp_nano},
physical_plan::functions::BuiltinScalarFunction,
};
use arrow::datatypes::DataType;
use std::collections::HashSet;

Expand All @@ -496,4 +636,53 @@ mod tests {
assert!(accum.contains(&Column::from_name("a")));
Ok(())
}

#[test]
fn test_expr_evaluator() {
    // concat('foo', 'bar')
    let concat = Expr::ScalarFunction {
        args: vec![lit("foo"), lit("bar")],
        fun: BuiltinScalarFunction::Concat,
    };

    // to_timestamp('2020-09-08T12:00:00+00:00')
    let to_timestamp = Expr::ScalarFunction {
        args: vec![lit("2020-09-08T12:00:00+00:00")],
        fun: BuiltinScalarFunction::ToTimestamp,
    };

    // Each entry is `(input, expected result of partial evaluation)`
    let cases = vec![
        // pure boolean literals
        (lit(true), lit(true)),
        (lit(true).or(lit(true)), lit(true)),
        (lit(true).or(lit(false)), lit(true)),
        // "foo" == "foo"
        (lit("foo").eq(lit("foo")), lit(true)),
        // "foo" != "foo"
        (lit("foo").not_eq(lit("foo")), lit(false)),
        // c = 1 is left alone as it references a column
        (col("c").eq(lit(1)), col("c").eq(lit(1))),
        // c = 1 + 2 --> c = 3
        (col("c").eq(lit(1) + lit(2)), col("c").eq(lit(3))),
        // only the constant side of the OR is folded
        (
            (lit("foo").not_eq(lit("foo"))).or(col("c").eq(lit(1))),
            lit(false).or(col("c").eq(lit(1))),
        ),
        // function evaluation
        (concat, lit("foobar")),
        (to_timestamp, lit_timestamp_nano(1599566400000000000i64)),
    ];

    for (input_expr, expected_expr) in cases {
        test_evaluate(input_expr, expected_expr);
    }

    // TODO write some more tests for:
    // to timestamp with col arguments
    // now()
    // volatile functions, etc (rand)
}

/// Partially evaluates `input_expr` and asserts the result equals
/// `expected_expr`, panicking if the evaluation itself fails.
fn test_evaluate(input_expr: Expr, expected_expr: Expr) {
    // evaluate a copy: the input is still needed for the failure message
    let outcome = partially_evaluate_expr(input_expr.clone());
    let actual_expr = outcome.unwrap();

    assert_eq!(
        actual_expr, expected_expr,
        "Mismatch evaluating {}\n Expected:{}\n Got:{}",
        input_expr, expected_expr, actual_expr
    );
}
}