Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@ run:

debug:
RUST_BACKTRACE=1 cargo run

debug_v2:
ENABLE_V2=1 RUST_BACKTRACE=1 cargo run
3 changes: 2 additions & 1 deletion src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::env;
use std::fs::File;
use std::sync::Arc;

Expand All @@ -13,7 +14,7 @@ pub async fn interactive(db: Database, client_context: Arc<ClientContext>) -> Re
let mut rl = Editor::<()>::new()?;
load_history(&mut rl);

let mut enable_v2 = false;
let mut enable_v2 = env::var("ENABLE_V2").unwrap_or_else(|_| "0".to_string()) == "1";

loop {
let read_sql = read_sql(&mut rl);
Expand Down
6 changes: 6 additions & 0 deletions src/execution/expression_executor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use arrow::array::ArrayRef;
use arrow::compute::cast;
use arrow::record_batch::RecordBatch;

use super::ExecutorError;
Expand Down Expand Up @@ -28,6 +29,11 @@ impl ExpressionExecutor {
BoundExpression::BoundColumnRefExpression(_) => todo!(),
BoundExpression::BoundConstantExpression(e) => e.value.to_array(),
BoundExpression::BoundReferenceExpression(e) => input.column(e.index).clone(),
BoundExpression::BoundCastExpression(e) => {
let child_result = Self::execute_internal(&e.child, input)?;
let to_type = e.base.return_type.clone().into();
cast(&child_result, &to_type)?
}
})
}
}
34 changes: 34 additions & 0 deletions src/planner_v2/binder/expression/bind_cast_expression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use derive_new::new;

use super::{BoundExpression, BoundExpressionBase};
use crate::types_v2::LogicalType;

#[derive(new, Debug, Clone)]
pub struct BoundCastExpression {
pub(crate) base: BoundExpressionBase,
/// The child type
pub(crate) child: Box<BoundExpression>,
#[allow(dead_code)]
/// Whether to use try_cast or not. try_cast converts cast failures into NULLs instead of
/// throwing an error.
pub(crate) try_cast: bool,
}

impl BoundCastExpression {
pub fn add_cast_to_type(
expr: BoundExpression,
target_type: LogicalType,
alias: String,
try_cast: bool,
) -> BoundExpression {
if expr.return_type() == target_type {
return expr;
}
let base = BoundExpressionBase::new(alias, target_type);
BoundExpression::BoundCastExpression(BoundCastExpression::new(
base,
Box::new(expr),
try_cast,
))
}
}
5 changes: 5 additions & 0 deletions src/planner_v2/binder/expression/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod bind_cast_expression;
mod bind_column_ref_expression;
mod bind_constant_expression;
mod bind_reference_expression;
mod column_binding;

pub use bind_cast_expression::*;
pub use bind_column_ref_expression::*;
pub use bind_constant_expression::*;
pub use bind_reference_expression::*;
Expand All @@ -24,6 +26,7 @@ pub enum BoundExpression {
BoundColumnRefExpression(BoundColumnRefExpression),
BoundConstantExpression(BoundConstantExpression),
BoundReferenceExpression(BoundReferenceExpression),
BoundCastExpression(BoundCastExpression),
}

impl BoundExpression {
Expand All @@ -32,6 +35,7 @@ impl BoundExpression {
BoundExpression::BoundColumnRefExpression(expr) => expr.base.return_type.clone(),
BoundExpression::BoundConstantExpression(expr) => expr.base.return_type.clone(),
BoundExpression::BoundReferenceExpression(expr) => expr.base.return_type.clone(),
BoundExpression::BoundCastExpression(expr) => expr.base.return_type.clone(),
}
}

Expand All @@ -40,6 +44,7 @@ impl BoundExpression {
BoundExpression::BoundColumnRefExpression(expr) => expr.base.alias.clone(),
BoundExpression::BoundConstantExpression(expr) => expr.base.alias.clone(),
BoundExpression::BoundReferenceExpression(expr) => expr.base.alias.clone(),
BoundExpression::BoundCastExpression(expr) => expr.base.alias.clone(),
}
}
}
46 changes: 45 additions & 1 deletion src/planner_v2/binder/query_node/plan_select_node.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use super::BoundSelectNode;
use crate::planner_v2::BoundTableRef::{BoundBaseTableRef, BoundExpressionListRef};
use crate::planner_v2::{
BindError, Binder, BoundStatement, LogicalOperator, LogicalOperatorBase, LogicalProjection,
BindError, Binder, BoundCastExpression, BoundStatement, LogicalOperator, LogicalOperatorBase,
LogicalProjection,
};
use crate::types_v2::LogicalType;

impl Binder {
pub fn create_plan_for_select_node(
Expand All @@ -23,4 +25,46 @@ impl Binder {

Ok(BoundStatement::new(root, node.types, node.names))
}

pub fn cast_logical_operator_to_types(
&mut self,
source_types: &[LogicalType],
target_types: &[LogicalType],
op: &mut LogicalOperator,
) -> Result<(), BindError> {
assert!(source_types.len() == target_types.len());
if source_types == target_types {
// source and target types are equal: don't need to cast
return Ok(());
}
if let LogicalOperator::LogicalProjection(node) = op {
// "node" is a projection; we can just do the casts in there
assert!(node.base.expressioins.len() == source_types.len());
for (idx, (source_type, target_type)) in
source_types.iter().zip(target_types.iter()).enumerate()
{
if source_type != target_type {
if LogicalType::can_implicit_cast(source_type, target_type) {
let alias = node.base.expressioins[idx].alias();
node.base.expressioins[idx] = BoundCastExpression::add_cast_to_type(
node.base.expressioins[idx].clone(),
target_type.clone(),
alias,
false,
);
node.base.types[idx] = target_type.clone();
} else {
return Err(BindError::Internal(format!(
"cannot cast {:?} to {:?}",
source_type, target_type
)));
}
}
}
Ok(())
} else {
// found a non-projection operator, push a new projection containing the casts
todo!();
}
}
}
12 changes: 8 additions & 4 deletions src/planner_v2/binder/statement/bind_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ impl Binder {

let select_node = self.bind_select_node(source)?;
let expected_columns_cnt = named_column_indices.len();

// special case: check if we are inserting from a VALUES statement
if let BoundTableRef::BoundExpressionListRef(table_ref) = &select_node.from_table {
// CheckInsertColumnCountMismatch
let insert_columns_cnt = table_ref.values.first().unwrap().len();
Expand All @@ -79,12 +81,14 @@ impl Binder {
expected_columns_cnt, insert_columns_cnt
)));
}
};

// TODO: cast types
}

let select_node = self.create_plan_for_select_node(select_node)?;
let plan = select_node.plan;
let inserted_types = select_node.types;
let mut plan = select_node.plan;
// cast inserted types to expected types when necessary
self.cast_logical_operator_to_types(&inserted_types, &expected_types, &mut plan)?;

let root = LogicalInsert::new(
LogicalOperatorBase::new(vec![plan], vec![], vec![]),
column_index_list,
Expand Down
26 changes: 19 additions & 7 deletions src/planner_v2/binder/tableref/bind_expression_list_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,37 @@ impl Binder {
&mut self,
values: &Values,
) -> Result<BoundExpressionListRef, BindError> {
// ensure all values lists are the same length
let mut values_cnt = 0;
for val_expr_list in values.0.iter() {
if values_cnt == 0 {
values_cnt = val_expr_list.len();
} else if values_cnt != val_expr_list.len() {
return Err(BindError::Internal(
"VALUES lists must all be the same length".to_string(),
));
}
}

let mut bound_expr_list = vec![];
let mut names = vec![];
let mut types = vec![];
let mut finish_name = false;
let mut names = vec!["".to_string(); values_cnt];
let mut types = vec![LogicalType::Invalid; values_cnt];

let mut expr_binder = ExpressionBinder::new(self);

for val_expr_list in values.0.iter() {
let mut bound_expr_row = vec![];
for (idx, expr) in val_expr_list.iter().enumerate() {
let bound_expr = expr_binder.bind_expression(expr, &mut vec![], &mut vec![])?;
if !finish_name {
names.push(format!("col{}", idx));
types.push(bound_expr.return_type());
names[idx] = format!("col{}", idx);
if types[idx] == LogicalType::Invalid {
types[idx] = bound_expr.return_type().clone();
}
// use values max type as the column type
types[idx] = LogicalType::max_logical_type(&types[idx], &bound_expr.return_type())?;
bound_expr_row.push(bound_expr);
}
bound_expr_list.push(bound_expr_row);
finish_name = true;
}
let table_index = self.generate_table_index();
self.bind_context.add_generic_binding(
Expand Down
3 changes: 2 additions & 1 deletion src/planner_v2/expression_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use super::BoundExpression;
pub struct ExpressionIterator;

impl ExpressionIterator {
pub fn enumerate_children<F>(expr: &mut BoundExpression, _callback: F)
pub fn enumerate_children<F>(expr: &mut BoundExpression, callback: F)
where
F: Fn(&mut BoundExpression),
{
Expand All @@ -13,6 +13,7 @@ impl ExpressionIterator {
| BoundExpression::BoundReferenceExpression(_) => {
// these node types have no children
}
BoundExpression::BoundCastExpression(e) => callback(&mut e.child),
}
}
}
8 changes: 6 additions & 2 deletions src/planner_v2/logical_operator_visitor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
BoundColumnRefExpression, BoundConstantExpression, BoundExpression, BoundReferenceExpression,
ExpressionIterator, LogicalOperator,
BoundCastExpression, BoundColumnRefExpression, BoundConstantExpression, BoundExpression,
BoundReferenceExpression, ExpressionIterator, LogicalOperator,
};

/// Visitor pattern on logical operators, also includes rewrite expression ability.
Expand Down Expand Up @@ -34,6 +34,7 @@ pub trait LogicalOperatorVisitor {
BoundExpression::BoundColumnRefExpression(e) => self.visit_replace_column_ref(e),
BoundExpression::BoundConstantExpression(e) => self.visit_replace_constant(e),
BoundExpression::BoundReferenceExpression(e) => self.visit_replace_reference(e),
BoundExpression::BoundCastExpression(e) => self.visit_replace_cast(e),
};
if let Some(new_expr) = result {
*expr = new_expr;
Expand All @@ -55,4 +56,7 @@ pub trait LogicalOperatorVisitor {
fn visit_replace_reference(&self, _: &BoundReferenceExpression) -> Option<BoundExpression> {
None
}
fn visit_replace_cast(&self, _: &BoundCastExpression) -> Option<BoundExpression> {
None
}
}
10 changes: 10 additions & 0 deletions src/planner_v2/operator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ impl LogicalOperator {
}
}

pub fn types(&self) -> &[LogicalType] {
match self {
LogicalOperator::LogicalCreateTable(op) => &op.base.types,
LogicalOperator::LogicalExpressionGet(op) => &op.base.types,
LogicalOperator::LogicalInsert(op) => &op.base.types,
LogicalOperator::LogicalGet(op) => &op.base.types,
LogicalOperator::LogicalProjection(op) => &op.base.types,
}
}

pub fn get_column_bindings(&self) -> Vec<ColumnBinding> {
let default = vec![ColumnBinding::new(0, 0)];
match self {
Expand Down
2 changes: 2 additions & 0 deletions src/types_v2/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ pub enum TypeError {
NotImplementedArrowDataType(String),
#[error("not implemented sqlparser datatype: {0}")]
NotImplementedSqlparserDataType(String),
#[error("internal error: {0}")]
InternalError(String),
}
Loading