Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rust/datafusion/examples/memory_table_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow::util::pretty;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use datafusion::logicalplan::{Expr, ScalarValue};
use datafusion::logicalplan::lit;

/// This example demonstrates basic uses of the Table API on an in-memory table
fn main() -> Result<()> {
Expand Down Expand Up @@ -54,7 +54,7 @@ fn main() -> Result<()> {
let t = ctx.table("t")?;

// construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL
let filter = t.col("b")?.eq(&Expr::Literal(ScalarValue::Int32(10)));
let filter = t.col("b")?.eq(&lit(10));

let t = t.select_columns(vec!["a", "b"])?.filter(filter)?;

Expand Down
51 changes: 46 additions & 5 deletions rust/datafusion/src/logicalplan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,50 @@ pub fn col(name: &str) -> Expr {
Expr::UnresolvedColumn(name.to_owned())
}

/// Create a literal string expression
pub fn lit_str(str: &str) -> Expr {
Expr::Literal(ScalarValue::Utf8(str.to_owned()))
/// Whether it can be represented as a literal expression
pub trait Literal {
/// convert the value to a Literal expression
fn lit(&self) -> Expr;
}

impl Literal for &str {
fn lit(&self) -> Expr {
Expr::Literal(ScalarValue::Utf8((*self).to_owned()))
}
}

impl Literal for String {
fn lit(&self) -> Expr {
Expr::Literal(ScalarValue::Utf8((*self).to_owned()))
}
}

macro_rules! make_literal {
($TYPE:ty, $SCALAR:ident) => {
#[allow(missing_docs)]
impl Literal for $TYPE {
fn lit(&self) -> Expr {
Expr::Literal(ScalarValue::$SCALAR(self.clone()))
}
}
};
}

make_literal!(bool, Boolean);
make_literal!(f32, Float32);
make_literal!(f64, Float64);
make_literal!(i8, Int8);
make_literal!(i16, Int16);
make_literal!(i32, Int32);
make_literal!(i64, Int64);
make_literal!(u8, UInt8);
make_literal!(u16, UInt16);
make_literal!(u32, UInt32);
make_literal!(u64, UInt64);

/// Create a literal expression
pub fn lit<T: Literal>(n: T) -> Expr {
n.lit()
}

/// Create an convenience function representing a unary scalar function
Expand Down Expand Up @@ -965,7 +1006,7 @@ mod tests {
&employee_schema(),
Some(vec![0, 3]),
)?
.filter(col("state").eq(&lit_str("CO")))?
.filter(col("state").eq(&lit("CO")))?
.project(vec![col("id")])?
.build()?;

Expand All @@ -985,7 +1026,7 @@ mod tests {
CsvReadOptions::new().schema(&employee_schema()),
Some(vec![0, 3]),
)?
.filter(col("state").eq(&lit_str("CO")))?
.filter(col("state").eq(&lit("CO")))?
.project(vec![col("id")])?
.build()?;

Expand Down
7 changes: 2 additions & 5 deletions rust/datafusion/src/optimizer/projection_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ fn get_projected_schema(
mod tests {

use super::*;
use crate::logicalplan::lit;
use crate::logicalplan::Expr::*;
use crate::logicalplan::ScalarValue;
use crate::test::*;
use arrow::datatypes::DataType;

Expand Down Expand Up @@ -498,10 +498,7 @@ mod tests {
fn table_scan_with_literal_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.project(vec![
Expr::Literal(ScalarValue::Int64(1)),
Expr::Literal(ScalarValue::Int64(2)),
])?
.project(vec![lit(1_i64), lit(2_i64)])?
.build()?;
let expected = "Projection: Int64(1), Int64(2)\
\n TableScan: test projection=Some([0])";
Expand Down
18 changes: 6 additions & 12 deletions rust/datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::sync::Arc;

use crate::error::{ExecutionError, Result};
use crate::logicalplan::{
Expr, FunctionMeta, LogicalPlan, LogicalPlanBuilder, Operator, ScalarValue,
lit, Expr, FunctionMeta, LogicalPlan, LogicalPlanBuilder, Operator, ScalarValue,
};

use arrow::datatypes::*;
Expand Down Expand Up @@ -262,14 +262,10 @@ impl<S: SchemaProvider> SqlToRel<S> {
/// Generate a relational expression from a SQL expression
pub fn sql_to_rex(&self, sql: &ASTNode, schema: &Schema) -> Result<Expr> {
match *sql {
ASTNode::SQLValue(sqlparser::sqlast::Value::Long(n)) => {
Ok(Expr::Literal(ScalarValue::Int64(n)))
}
ASTNode::SQLValue(sqlparser::sqlast::Value::Double(n)) => {
Ok(Expr::Literal(ScalarValue::Float64(n)))
}
ASTNode::SQLValue(sqlparser::sqlast::Value::Long(n)) => Ok(lit(n)),
ASTNode::SQLValue(sqlparser::sqlast::Value::Double(n)) => Ok(lit(n)),
ASTNode::SQLValue(sqlparser::sqlast::Value::SingleQuotedString(ref s)) => {
Ok(Expr::Literal(ScalarValue::Utf8(s.clone())))
Ok(lit(s.clone()))
}

ASTNode::SQLAliasedExpr(ref expr, ref alias) => Ok(Alias(
Expand Down Expand Up @@ -382,11 +378,9 @@ impl<S: SchemaProvider> SqlToRel<S> {
.iter()
.map(|a| match a {
ASTNode::SQLValue(sqlparser::sqlast::Value::Long(_)) => {
Ok(Expr::Literal(ScalarValue::UInt8(1)))
}
ASTNode::SQLWildcard => {
Ok(Expr::Literal(ScalarValue::UInt8(1)))
Ok(lit(1_u8))
}
ASTNode::SQLWildcard => Ok(lit(1_u8)),
_ => self.sql_to_rex(a, schema),
})
.collect::<Result<Vec<Expr>>>()?;
Expand Down