Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 49 additions & 3 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1788,22 +1788,24 @@ impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> {

#[cfg(test)]
mod tests {
use std::env;
use std::path::PathBuf;

use super::{super::options::CsvReadOptions, *};
use crate::assert_batches_eq;
use crate::execution::memory_pool::MemoryConsumer;
use crate::execution::runtime_env::RuntimeEnvBuilder;
use crate::test;
use crate::test_util::{plan_and_collect, populate_csv_partitions};
use arrow_schema::{DataType, TimeUnit};
use std::env;
use std::path::PathBuf;

use datafusion_common_runtime::SpawnedTask;

use crate::catalog::SchemaProvider;
use crate::execution::session_state::SessionStateBuilder;
use crate::physical_planner::PhysicalPlanner;
use async_trait::async_trait;
use datafusion_expr::planner::TypePlanner;
use sqlparser::ast;
use tempfile::TempDir;

#[tokio::test]
Expand Down Expand Up @@ -2200,6 +2202,29 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn custom_type_planner() -> Result<()> {
let state = SessionStateBuilder::new()
.with_default_features()
.with_type_planner(Arc::new(MyTypePlanner {}))
.build();
let ctx = SessionContext::new_with_state(state);
let result = ctx
.sql("SELECT DATETIME '2021-01-01 00:00:00'")
.await?
.collect()
.await?;
let expected = [
"+-----------------------------+",
"| Utf8(\"2021-01-01 00:00:00\") |",
"+-----------------------------+",
"| 2021-01-01T00:00:00 |",
"+-----------------------------+",
];
assert_batches_eq!(expected, &result);
Ok(())
}

struct MyPhysicalPlanner {}

#[async_trait]
Expand Down Expand Up @@ -2260,4 +2285,25 @@ mod tests {

Ok(ctx)
}

#[derive(Debug)]
struct MyTypePlanner {}

impl TypePlanner for MyTypePlanner {
fn plan_type(&self, sql_type: &ast::DataType) -> Result<Option<DataType>> {
match sql_type {
ast::DataType::Datetime(precision) => {
let precision = match precision {
Some(0) => TimeUnit::Second,
Some(3) => TimeUnit::Millisecond,
Some(6) => TimeUnit::Microsecond,
None | Some(9) => TimeUnit::Nanosecond,
_ => unreachable!(),
};
Ok(Some(DataType::Timestamp(precision, None)))
}
_ => Ok(None),
}
}
}
}
30 changes: 29 additions & 1 deletion datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::TaskContext;
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::planner::ExprPlanner;
use datafusion_expr::planner::{ExprPlanner, TypePlanner};
use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry};
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::var_provider::{is_system_variables, VarType};
Expand Down Expand Up @@ -128,6 +128,8 @@ pub struct SessionState {
analyzer: Analyzer,
/// Provides support for customising the SQL planner, e.g. to add support for custom operators like `->>` or `?`
expr_planners: Vec<Arc<dyn ExprPlanner>>,
/// Provides support for customising the SQL type planning
type_planner: Option<Arc<dyn TypePlanner>>,
/// Responsible for optimizing a logical plan
optimizer: Optimizer,
/// Responsible for optimizing a physical execution plan
Expand Down Expand Up @@ -192,6 +194,7 @@ impl Debug for SessionState {
.field("table_factories", &self.table_factories)
.field("function_factory", &self.function_factory)
.field("expr_planners", &self.expr_planners)
.field("type_planner", &self.type_planner)
.field("query_planners", &self.query_planner)
.field("analyzer", &self.analyzer)
.field("optimizer", &self.optimizer)
Expand Down Expand Up @@ -943,6 +946,7 @@ pub struct SessionStateBuilder {
session_id: Option<String>,
analyzer: Option<Analyzer>,
expr_planners: Option<Vec<Arc<dyn ExprPlanner>>>,
type_planner: Option<Arc<dyn TypePlanner>>,
optimizer: Option<Optimizer>,
physical_optimizers: Option<PhysicalOptimizer>,
query_planner: Option<Arc<dyn QueryPlanner + Send + Sync>>,
Expand Down Expand Up @@ -972,6 +976,7 @@ impl SessionStateBuilder {
session_id: None,
analyzer: None,
expr_planners: None,
type_planner: None,
optimizer: None,
physical_optimizers: None,
query_planner: None,
Expand Down Expand Up @@ -1019,6 +1024,7 @@ impl SessionStateBuilder {
session_id: None,
analyzer: Some(existing.analyzer),
expr_planners: Some(existing.expr_planners),
type_planner: existing.type_planner,
optimizer: Some(existing.optimizer),
physical_optimizers: Some(existing.physical_optimizers),
query_planner: Some(existing.query_planner),
Expand Down Expand Up @@ -1113,6 +1119,12 @@ impl SessionStateBuilder {
self
}

/// Set the [`TypePlanner`] used to customize the behavior of the SQL planner.
pub fn with_type_planner(mut self, type_planner: Arc<dyn TypePlanner>) -> Self {
self.type_planner = Some(type_planner);
self
}

/// Set the [`PhysicalOptimizerRule`]s used to optimize plans.
pub fn with_physical_optimizer_rules(
mut self,
Expand Down Expand Up @@ -1306,6 +1318,7 @@ impl SessionStateBuilder {
session_id,
analyzer,
expr_planners,
type_planner,
optimizer,
physical_optimizers,
query_planner,
Expand Down Expand Up @@ -1334,6 +1347,7 @@ impl SessionStateBuilder {
session_id: session_id.unwrap_or(Uuid::new_v4().to_string()),
analyzer: analyzer.unwrap_or_default(),
expr_planners: expr_planners.unwrap_or_default(),
type_planner,
optimizer: optimizer.unwrap_or_default(),
physical_optimizers: physical_optimizers.unwrap_or_default(),
query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})),
Expand Down Expand Up @@ -1444,6 +1458,11 @@ impl SessionStateBuilder {
&mut self.expr_planners
}

/// Returns the current type_planner value
pub fn type_planner(&mut self) -> &mut Option<Arc<dyn TypePlanner>> {
&mut self.type_planner
}

/// Returns the current optimizer value
pub fn optimizer(&mut self) -> &mut Option<Optimizer> {
&mut self.optimizer
Expand Down Expand Up @@ -1566,6 +1585,7 @@ impl Debug for SessionStateBuilder {
.field("table_factories", &self.table_factories)
.field("function_factory", &self.function_factory)
.field("expr_planners", &self.expr_planners)
.field("type_planner", &self.type_planner)
.field("query_planners", &self.query_planner)
.field("analyzer_rules", &self.analyzer_rules)
.field("analyzer", &self.analyzer)
Expand Down Expand Up @@ -1607,6 +1627,14 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
&self.state.expr_planners
}

fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
if let Some(type_planner) = &self.state.type_planner {
Some(Arc::clone(type_planner))
} else {
None
}
}

fn get_table_source(
&self,
name: TableReference,
Expand Down
18 changes: 17 additions & 1 deletion datafusion/expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use datafusion_common::{
config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema,
Result, TableReference,
};
use sqlparser::ast;

use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF};

Expand Down Expand Up @@ -66,6 +67,11 @@ pub trait ContextProvider {
&[]
}

/// Getter for the data type planner
fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
None
}

/// Getter for a UDF description
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
/// Getter for a UDAF description
Expand Down Expand Up @@ -216,7 +222,7 @@ pub trait ExprPlanner: Debug + Send + Sync {
/// custom expressions.
#[derive(Debug, Clone)]
pub struct RawBinaryExpr {
pub op: sqlparser::ast::BinaryOperator,
pub op: ast::BinaryOperator,
pub left: Expr,
pub right: Expr,
}
Expand Down Expand Up @@ -249,3 +255,13 @@ pub enum PlannerResult<T> {
/// The raw expression could not be planned, and is returned unmodified
Original(T),
}

/// This trait allows users to customize the behavior of the data type planning
pub trait TypePlanner: Debug + Send + Sync {
/// Plan SQL type to DataFusion data type
///
/// Returns None if not possible
fn plan_type(&self, _sql_type: &ast::DataType) -> Result<Option<DataType>> {
Ok(None)
}
}
8 changes: 8 additions & 0 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}

pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
// First check if any of the registered type_planner can handle this type
if let Some(type_planner) = self.context_provider.get_type_planner() {
if let Some(data_type) = type_planner.plan_type(sql_type)? {
return Ok(data_type);
}
}

// If no type_planner can handle this type, use the default conversion
match sql_type {
SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) => {
// Arrays may be multi-dimensional.
Expand Down
57 changes: 53 additions & 4 deletions datafusion/sql/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@
use std::any::Any;
#[cfg(test)]
use std::collections::HashMap;
use std::fmt::Display;
use std::fmt::{Debug, Display};
use std::{sync::Arc, vec};

use arrow_schema::*;
use datafusion_common::config::ConfigOptions;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{plan_err, GetExt, Result, TableReference};
use datafusion_expr::planner::ExprPlanner;
use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
use datafusion_common::{plan_err, DFSchema, GetExt, Result, TableReference};
use datafusion_expr::planner::{ExprPlanner, PlannerResult, TypePlanner};
use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF};
use datafusion_functions_nested::expr_fn::make_array;
use datafusion_sql::planner::ContextProvider;

struct MockCsvType {}
Expand Down Expand Up @@ -54,6 +55,7 @@ pub(crate) struct MockSessionState {
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
expr_planners: Vec<Arc<dyn ExprPlanner>>,
type_planner: Option<Arc<dyn TypePlanner>>,
window_functions: HashMap<String, Arc<WindowUDF>>,
pub config_options: ConfigOptions,
}
Expand All @@ -64,6 +66,11 @@ impl MockSessionState {
self
}

pub fn with_type_planner(mut self, type_planner: Arc<dyn TypePlanner>) -> Self {
self.type_planner = Some(type_planner);
self
}

pub fn with_scalar_function(mut self, scalar_function: Arc<ScalarUDF>) -> Self {
self.scalar_functions
.insert(scalar_function.name().to_string(), scalar_function);
Expand Down Expand Up @@ -259,6 +266,14 @@ impl ContextProvider for MockContextProvider {
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
&self.state.expr_planners
}

fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
if let Some(type_planner) = &self.state.type_planner {
Some(Arc::clone(type_planner))
} else {
None
}
}
}

struct EmptyTable {
Expand All @@ -280,3 +295,37 @@ impl TableSource for EmptyTable {
Arc::clone(&self.table_schema)
}
}

#[derive(Debug)]
pub struct CustomTypePlanner {}

impl TypePlanner for CustomTypePlanner {
fn plan_type(&self, sql_type: &sqlparser::ast::DataType) -> Result<Option<DataType>> {
match sql_type {
sqlparser::ast::DataType::Datetime(precision) => {
let precision = match precision {
Some(0) => TimeUnit::Second,
Some(3) => TimeUnit::Millisecond,
Some(6) => TimeUnit::Microsecond,
None | Some(9) => TimeUnit::Nanosecond,
_ => unreachable!(),
};
Ok(Some(DataType::Timestamp(precision, None)))
}
_ => Ok(None),
}
}
}

#[derive(Debug)]
pub struct CustomExprPlanner {}

impl ExprPlanner for CustomExprPlanner {
fn plan_array_literal(
&self,
exprs: Vec<Expr>,
_schema: &DFSchema,
) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Planned(make_array(exprs)))
}
}
Loading