Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/execution/expression_executor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use arrow::array::ArrayRef;
use arrow::compute::cast;
use arrow::compute::{cast_with_options, CastOptions};
use arrow::record_batch::RecordBatch;

use super::ExecutorError;
Expand Down Expand Up @@ -32,7 +32,8 @@ impl ExpressionExecutor {
BoundExpression::BoundCastExpression(e) => {
let child_result = Self::execute_internal(&e.child, input)?;
let to_type = e.base.return_type.clone().into();
cast(&child_result, &to_type)?
let options = CastOptions { safe: e.try_cast };
cast_with_options(&child_result, &to_type, &options)?
}
})
}
Expand Down
1 change: 0 additions & 1 deletion src/planner_v2/binder/expression/bind_cast_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ pub struct BoundCastExpression {
pub(crate) base: BoundExpressionBase,
/// The child type
pub(crate) child: Box<BoundExpression>,
#[allow(dead_code)]
/// Whether to use try_cast or not. try_cast converts cast failures into NULLs instead of
/// throwing an error.
pub(crate) try_cast: bool,
Expand Down
24 changes: 9 additions & 15 deletions src/planner_v2/binder/query_node/plan_select_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,15 @@ impl Binder {
source_types.iter().zip(target_types.iter()).enumerate()
{
if source_type != target_type {
if LogicalType::can_implicit_cast(source_type, target_type) {
let alias = node.base.expressioins[idx].alias();
node.base.expressioins[idx] = BoundCastExpression::add_cast_to_type(
node.base.expressioins[idx].clone(),
target_type.clone(),
alias,
false,
);
node.base.types[idx] = target_type.clone();
} else {
return Err(BindError::Internal(format!(
"cannot cast {:?} to {:?}",
source_type, target_type
)));
}
// differing types, have to add a cast but may be lossy
let alias = node.base.expressioins[idx].alias();
node.base.expressioins[idx] = BoundCastExpression::add_cast_to_type(
node.base.expressioins[idx].clone(),
target_type.clone(),
alias,
false,
);
node.base.types[idx] = target_type.clone();
}
}
Ok(())
Expand Down
2 changes: 2 additions & 0 deletions src/planner_v2/binder/statement/bind_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ impl Binder {
let mut plan = select_node.plan;
// cast inserted types to expected types when necessary
self.cast_logical_operator_to_types(&inserted_types, &expected_types, &mut plan)?;
// TODO: add debug level log for plan
// println!("plan: {:#?}", plan);

let root = LogicalInsert::new(
LogicalOperatorBase::new(vec![plan], vec![], vec![]),
Expand Down
19 changes: 18 additions & 1 deletion src/planner_v2/binder/tableref/bind_expression_list_ref.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use derive_new::new;
use sqlparser::ast::Values;

use crate::planner_v2::{BindError, Binder, BoundExpression, ExpressionBinder};
use crate::planner_v2::{
BindError, Binder, BoundCastExpression, BoundExpression, ExpressionBinder,
};
use crate::types_v2::LogicalType;

pub static VALUES_LIST_ALIAS: &str = "valueslist";
Expand Down Expand Up @@ -56,6 +58,21 @@ impl Binder {
}
bound_expr_list.push(bound_expr_row);
}
// insert values contains SqlNull, the expr should be cast to the max logical type
for exprs in bound_expr_list.iter_mut() {
for (idx, bound_expr) in exprs.iter_mut().enumerate() {
if bound_expr.return_type() == LogicalType::SqlNull {
let alias = bound_expr.alias().clone();
*bound_expr = BoundCastExpression::add_cast_to_type(
bound_expr.clone(),
types[idx].clone(),
alias,
false,
)
}
}
}

let table_index = self.generate_table_index();
self.bind_context.add_generic_binding(
VALUES_LIST_ALIAS.to_string(),
Expand Down
78 changes: 67 additions & 11 deletions src/types_v2/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::TypeError;
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum LogicalType {
Invalid,
SqlNull,
Boolean,
Tinyint,
UTinyint,
Expand Down Expand Up @@ -36,37 +37,91 @@ impl LogicalType {
)
}

pub fn is_signed_numeric(&self) -> bool {
matches!(
self,
LogicalType::Tinyint
| LogicalType::Smallint
| LogicalType::Integer
| LogicalType::Bigint
)
}

pub fn is_unsigned_numeric(&self) -> bool {
matches!(
self,
LogicalType::UTinyint
| LogicalType::USmallint
| LogicalType::UInteger
| LogicalType::UBigint
)
}

pub fn max_logical_type(
left: &LogicalType,
right: &LogicalType,
) -> Result<LogicalType, TypeError> {
if left == right {
return Ok(left.clone());
}
match (left, right) {
// SqlNull type can be cast to anything
(LogicalType::SqlNull, _) => return Ok(right.clone()),
(_, LogicalType::SqlNull) => return Ok(left.clone()),
_ => {}
}
if left.is_numeric() && right.is_numeric() {
if LogicalType::can_implicit_cast(left, right) {
return Ok(right.clone());
} else if LogicalType::can_implicit_cast(right, left) {
return Ok(left.clone());
} else {
return Err(TypeError::InternalError(format!(
"can not implicit cast {:?} to {:?}",
left, right
)));
}
return LogicalType::combine_numeric_types(left, right);
}
Err(TypeError::InternalError(format!(
"can not compare two types: {:?} and {:?}",
left, right
)))
}

pub fn can_implicit_cast(from: &LogicalType, to: &LogicalType) -> bool {
fn combine_numeric_types(
left: &LogicalType,
right: &LogicalType,
) -> Result<LogicalType, TypeError> {
if left == right {
return Ok(left.clone());
}
if left.is_signed_numeric() && right.is_unsigned_numeric() {
// this method is symmetric
// arrange it so the left type is smaller
// to limit the number of options we need to check
return LogicalType::combine_numeric_types(right, left);
}

if LogicalType::can_implicit_cast(left, right) {
return Ok(right.clone());
}
if LogicalType::can_implicit_cast(right, left) {
return Ok(left.clone());
}
// we can't cast implicitly either way and types are not equal
// this happens when left is signed and right is unsigned
// e.g. INTEGER and UINTEGER
// in this case we need to upcast to make sure the types fit
match (left, right) {
(LogicalType::Bigint, _) | (_, LogicalType::UBigint) => Ok(LogicalType::Double),
(LogicalType::Integer, _) | (_, LogicalType::UInteger) => Ok(LogicalType::Bigint),
(LogicalType::Smallint, _) | (_, LogicalType::USmallint) => Ok(LogicalType::Integer),
(LogicalType::Tinyint, _) | (_, LogicalType::UTinyint) => Ok(LogicalType::Smallint),
_ => Err(TypeError::InternalError(format!(
"can not combine these numeric types {:?} and {:?}",
left, right
))),
}
}

fn can_implicit_cast(from: &LogicalType, to: &LogicalType) -> bool {
if from == to {
return true;
}
match from {
LogicalType::Invalid => false,
LogicalType::SqlNull => true,
LogicalType::Boolean => false,
LogicalType::Tinyint => matches!(
to,
Expand Down Expand Up @@ -160,6 +215,7 @@ impl From<LogicalType> for arrow::datatypes::DataType {
use arrow::datatypes::DataType;
match value {
LogicalType::Invalid => panic!("invalid logical type"),
LogicalType::SqlNull => DataType::Null,
LogicalType::Boolean => DataType::Boolean,
LogicalType::Tinyint => DataType::Int8,
LogicalType::UTinyint => DataType::UInt8,
Expand Down
2 changes: 1 addition & 1 deletion src/types_v2/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ impl ScalarValue {

pub fn get_logical_type(&self) -> LogicalType {
match self {
ScalarValue::Null => LogicalType::Invalid,
ScalarValue::Null => LogicalType::SqlNull,
ScalarValue::Boolean(_) => LogicalType::Boolean,
ScalarValue::Float32(_) => LogicalType::Float,
ScalarValue::Float64(_) => LogicalType::Double,
Expand Down
38 changes: 8 additions & 30 deletions tests/slt/create_table.slt
Original file line number Diff line number Diff line change
@@ -1,43 +1,21 @@
onlyif sqlrs_v2
statement ok
create table t1(a varchar, b varchar, c varchar);
create table t1(v1 varchar, v2 varchar, v3 varchar);
insert into t1 values('a', 'b', 'c');

onlyif sqlrs_v2
statement ok
insert into t1(c, b) values ('0','4'),('1','5');

onlyif sqlrs_v2
statement ok
insert into t1 values ('2','7','9');

onlyif sqlrs_v2
query III
select a, c, b from t1;
----
NULL 0 4
NULL 1 5
2 9 7
statement error
create table t1(v1 int);

onlyif sqlrs_v2
statement ok
create table t2(a int, b int, c int);

onlyif sqlrs_v2
statement ok
insert into t2(c, b, a) values (0, 4, 1), (1, 5, 2);
create table t2(v1 boolean, v2 tinyint, v3 smallint, v4 int, v5 bigint, v6 float, v7 double, v8 varchar);
insert into t2 values(true, 1, 2, 3, 4, 5.1, 6.2, '7');

onlyif sqlrs_v2
query III
select c, b, a from t2;
----
0 4 1
1 5 2

# Test insert type cast
onlyif sqlrs_v2
statement ok
create table t3(a TINYINT UNSIGNED);

onlyif sqlrs_v2
statement error
insert into t3(a) values (1481);
create table t3(v1 boolean, v2 tinyint unsigned, v3 smallint unsigned, v4 int unsigned, v5 bigint unsigned, v6 float, v7 double, v8 varchar);
insert into t3 values(true, 1, 2, 3, 4, 5.1, 6.2, '7');
69 changes: 69 additions & 0 deletions tests/slt/insert_table.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Test common insert case
onlyif sqlrs_v2
statement ok
create table t1(v1 varchar, v2 varchar, v3 varchar);

onlyif sqlrs_v2
statement ok
insert into t1(v3, v2) values ('0','4'), ('1','5');

onlyif sqlrs_v2
statement ok
insert into t1 values ('2','7','9');

onlyif sqlrs_v2
query III
select v1, v3, v2 from t1;
----
NULL 0 4
NULL 1 5
2 9 7


# Test insert value cast type
onlyif sqlrs_v2
statement ok
create table t2(v1 int, v2 int, v3 int);

onlyif sqlrs_v2
statement ok
insert into t2(v3, v2, v1) values (0, 4, 1), (1, 5, 2);

onlyif sqlrs_v2
query III
select v3, v2, v1 from t2;
----
0 4 1
1 5 2


# Test insert type cast
onlyif sqlrs_v2
statement ok
create table t3(v1 TINYINT UNSIGNED);

onlyif sqlrs_v2
statement error
insert into t3(v1) values (1481);


# Test insert null values
onlyif sqlrs_v2
statement ok
create table t4(v1 varchar, v2 smallint unsigned, v3 bigint unsigned);

onlyif sqlrs_v2
statement ok
insert into t4 values (NULL, 1, 2), ('', 3, NULL);

onlyif sqlrs_v2
statement ok
insert into t4 values (NULL, NULL, NULL);

onlyif sqlrs_v2
query III
select v1, v2, v3 from t4;
----
NULL 1 2
(empty) 3 NULL
NULL NULL NULL