Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/binder/create_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ impl<'a, T: Transaction> Binder<'a, T> {
let mut set = HashSet::new();
for col in columns.iter() {
let col_name = &col.name.value;
if !set.insert(col_name.clone()) {
return Err(DatabaseError::AmbiguousColumn(col_name.to_string()));
if !set.insert(col_name) {
return Err(DatabaseError::DuplicateColumn(col_name.clone()));
}
if !is_valid_identifier(col_name) {
return Err(DatabaseError::InvalidColumn(
Expand Down Expand Up @@ -122,7 +122,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
DataValue::clone(&value).cast(&column_desc.column_datatype)?;
column_desc.default = Some(Arc::new(cast_value));
} else {
unreachable!("'default' only for constant")
return Err(DatabaseError::UnsupportedStmt(
"'default' only for constant".to_string(),
));
}
}
_ => todo!(),
Expand Down
40 changes: 27 additions & 13 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use sqlparser::ast::{
use std::slice;
use std::sync::Arc;

use super::{lower_ident, Binder};
use super::{lower_ident, Binder, QueryBindStep};
use crate::expression::function::{FunctionSummary, ScalarFunction};
use crate::expression::{AliasType, ScalarExpression};
use crate::storage::Transaction;
Expand All @@ -27,6 +27,14 @@ macro_rules! try_alias {
};
}

macro_rules! try_default {
($table_name:expr, $column_name:expr) => {
if let (None, "default") = ($table_name, $column_name.as_str()) {
return Ok(ScalarExpression::Empty);
}
};
}

impl<'a, T: Transaction> Binder<'a, T> {
pub(crate) fn bind_expr(&mut self, expr: &Expr) -> Result<ScalarExpression, DatabaseError> {
match expr {
Expand Down Expand Up @@ -102,17 +110,21 @@ impl<'a, T: Transaction> Binder<'a, T> {
));
}
let column = sub_query_schema[0].clone();
let mut alias_column = ColumnCatalog::clone(&column);
alias_column.set_table_name(self.context.temp_table());

self.context.sub_query(sub_query);

Ok(ScalarExpression::Alias {
expr: Box::new(ScalarExpression::ColumnRef(column)),
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new(
alias_column,
)))),
})
if self.context.is_step(&QueryBindStep::Where) {
let mut alias_column = ColumnCatalog::clone(&column);
alias_column.set_table_name(self.context.temp_table());

Ok(ScalarExpression::Alias {
expr: Box::new(ScalarExpression::ColumnRef(column)),
alias: AliasType::Expr(Box::new(ScalarExpression::ColumnRef(Arc::new(
alias_column,
)))),
})
} else {
Ok(ScalarExpression::ColumnRef(column))
}
}
Expr::Tuple(exprs) => {
let mut bond_exprs = Vec::with_capacity(exprs.len());
Expand Down Expand Up @@ -215,8 +227,11 @@ impl<'a, T: Transaction> Binder<'a, T> {
))
}
};
try_alias!(self.context, column_name);
if self.context.allow_default {
try_default!(&table_name, column_name);
}
if let Some(table) = table_name.or(bind_table_name) {
try_alias!(self.context, column_name);
let table_catalog = self
.context
.table(Arc::new(table.clone()))
Expand All @@ -227,10 +242,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
.ok_or_else(|| DatabaseError::NotFound("column", column_name))?;
Ok(ScalarExpression::ColumnRef(column_catalog.clone()))
} else {
try_alias!(self.context, column_name);
// handle col syntax
let mut got_column = None;
for (table_catalog, _) in self.context.bind_table.values() {
for table_catalog in self.context.bind_table.values() {
if let Some(column_catalog) = table_catalog.get_column_by_name(&column_name) {
got_column = Some(column_catalog);
}
Expand Down
47 changes: 25 additions & 22 deletions src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ impl<'a, T: Transaction> Binder<'a, T> {
expr_rows: &Vec<Vec<Expr>>,
is_overwrite: bool,
) -> Result<LogicalPlan, DatabaseError> {
// FIXME: Make it better to detect the current BindStep
self.context.allow_default = true;
let table_name = Arc::new(lower_case_name(name)?);

if let Some(table) = self.context.table(table_name.clone()) {
Expand All @@ -43,7 +45,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
Some(table_name.to_string()),
)? {
ScalarExpression::ColumnRef(catalog) => columns.push(catalog),
_ => unreachable!(),
_ => return Err(DatabaseError::UnsupportedStmt(ident.to_string())),
}
}
if values_len != columns.len() {
Expand All @@ -53,37 +55,41 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
let schema_ref = _schema_ref.ok_or(DatabaseError::ColumnsEmpty)?;
let mut rows = Vec::with_capacity(expr_rows.len());

for expr_row in expr_rows {
if expr_row.len() != values_len {
return Err(DatabaseError::ValuesLenMismatch(expr_row.len(), values_len));
}
let mut row = Vec::with_capacity(expr_row.len());

for (i, expr) in expr_row.iter().enumerate() {
match &self.bind_expr(expr)? {
ScalarExpression::Constant(value) => {
let mut expression = self.bind_expr(expr)?;

expression.constant_calculation()?;
match expression {
ScalarExpression::Constant(mut value) => {
let ty = schema_ref[i].datatype();
// Check if the value length is too long
value.check_len(schema_ref[i].datatype())?;
let cast_value =
DataValue::clone(value).cast(schema_ref[i].datatype())?;
row.push(Arc::new(cast_value))
}
ScalarExpression::Unary { expr, op, .. } => {
if let ScalarExpression::Constant(value) = expr.as_ref() {
row.push(Arc::new(
DataValue::unary_op(value, op)?
.cast(schema_ref[i].datatype())?,
))
} else {
unreachable!()
value.check_len(ty)?;

if value.logical_type() != *ty {
value = Arc::new(DataValue::clone(&value).cast(ty)?);
}
row.push(value);
}
ScalarExpression::Empty => {
row.push(schema_ref[i].default_value().ok_or_else(|| {
DatabaseError::InvalidColumn(
"column does not exist default".to_string(),
)
})?);
}
_ => unreachable!(),
_ => return Err(DatabaseError::UnsupportedStmt(expr.to_string())),
}
}

rows.push(row);
}
self.context.allow_default = false;
let values_plan = self.bind_values(rows, schema_ref);

Ok(LogicalPlan::new(
Expand All @@ -94,10 +100,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
vec![values_plan],
))
} else {
Err(DatabaseError::InvalidTable(format!(
"not found table {}",
table_name
)))
Err(DatabaseError::TableNotFound)
}
}

Expand Down
19 changes: 11 additions & 8 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ mod truncate;
mod update;

use sqlparser::ast::{Ident, ObjectName, ObjectType, SetExpr, Statement};
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;

use crate::catalog::{TableCatalog, TableName};
Expand Down Expand Up @@ -50,7 +50,8 @@ pub enum QueryBindStep {
pub struct BinderContext<'a, T: Transaction> {
functions: &'a Functions,
pub(crate) transaction: &'a T,
pub(crate) bind_table: HashMap<TableName, (&'a TableCatalog, Option<JoinType>)>,
// Tips: When there are multiple tables and Wildcard, use BTreeMap to ensure that the order of the output tables is certain.
pub(crate) bind_table: BTreeMap<(TableName, Option<JoinType>), &'a TableCatalog>,
// alias
expr_aliases: HashMap<String, ScalarExpression>,
table_aliases: HashMap<String, TableName>,
Expand All @@ -62,6 +63,7 @@ pub struct BinderContext<'a, T: Transaction> {
sub_queries: HashMap<QueryBindStep, Vec<LogicalPlan>>,

temp_table_id: usize,
pub(crate) allow_default: bool,
}

impl<'a, T: Transaction> BinderContext<'a, T> {
Expand All @@ -77,6 +79,7 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
bind_step: QueryBindStep::From,
sub_queries: Default::default(),
temp_table_id: 0,
allow_default: false,
}
}

Expand All @@ -89,6 +92,10 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
self.bind_step = bind_step;
}

pub fn is_step(&self, bind_step: &QueryBindStep) -> bool {
&self.bind_step == bind_step
}

pub fn sub_query(&mut self, sub_query: LogicalPlan) {
self.sub_queries
.entry(self.bind_step)
Expand Down Expand Up @@ -120,12 +127,8 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
}
.ok_or(DatabaseError::TableNotFound)?;

let old_table = self
.bind_table
.insert(table_name.clone(), (table, join_type));
if matches!(old_table, Some((_, Some(_)))) {
return Err(DatabaseError::Duplicated("table", table_name.to_string()));
}
self.bind_table
.insert((table_name.clone(), join_type), table);

Ok(table)
}
Expand Down
57 changes: 41 additions & 16 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
types::value::DataValue,
};

use super::{lower_case_name, lower_ident, Binder, QueryBindStep};
use super::{lower_case_name, lower_ident, Binder, BinderContext, QueryBindStep};

use crate::catalog::{ColumnCatalog, ColumnSummary, TableName};
use crate::errors::DatabaseError;
Expand Down Expand Up @@ -338,7 +338,10 @@ impl<'a, T: Transaction> Binder<'a, T> {
let scan_op = ScanOperator::build(table_name.clone(), table_catalog);

if let Some(TableAlias { name, columns }) = alias {
self.register_alias(columns, name.value.to_lowercase(), table_name.clone())?;
let alias = lower_ident(name);
self.register_alias(columns, alias.clone(), table_name.clone())?;

return Ok((Arc::new(alias), scan_op));
}

Ok((table_name, scan_op))
Expand Down Expand Up @@ -371,7 +374,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
});
}
SelectItem::Wildcard(_) => {
for table_name in self.context.bind_table.keys() {
for (table_name, _) in self.context.bind_table.keys() {
self.bind_table_column_refs(&mut select_items, table_name.clone())?;
}
}
Expand Down Expand Up @@ -443,19 +446,16 @@ impl<'a, T: Transaction> Binder<'a, T> {
};
let (right_table, right) = self.bind_single_table_ref(relation, Some(join_type))?;
let right_table = Self::unpack_name(right_table, false);
let fn_table = |context: &BinderContext<_>, table| {
context
.table(table)
.map(|table| table.schema_ref())
.cloned()
.ok_or(DatabaseError::TableNotFound)
};

let left_table = self
.context
.table(left_table)
.map(|table| table.schema_ref())
.cloned()
.ok_or(DatabaseError::TableNotFound)?;
let right_table = self
.context
.table(right_table)
.map(|table| table.schema_ref())
.cloned()
.ok_or(DatabaseError::TableNotFound)?;
let left_table = fn_table(&self.context, left_table.clone())?;
let right_table = fn_table(&self.context, right_table.clone())?;

let on = match joint_condition {
Some(constraint) => self.bind_join_constraint(&left_table, &right_table, constraint)?,
Expand Down Expand Up @@ -605,7 +605,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
let mut left_table_force_nullable = false;
let mut left_table = None;

for (table, join_option) in bind_tables.values() {
for ((_, join_option), table) in bind_tables {
if let Some(join_type) = join_option {
let (left_force_nullable, right_force_nullable) = joins_nullable(join_type);
table_force_nullable.push((table, right_force_nullable));
Expand Down Expand Up @@ -671,6 +671,31 @@ impl<'a, T: Transaction> Binder<'a, T> {
filter: join_filter,
})
}
JoinConstraint::Using(idents) => {
let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![];
let fn_column = |schema: &Schema, ident: &Ident| {
schema
.iter()
.find(|column| column.name() == lower_ident(ident))
.map(|column| ScalarExpression::ColumnRef(column.clone()))
};

for ident in idents {
if let (Some(left_column), Some(right_column)) = (
fn_column(left_schema, ident),
fn_column(right_schema, ident),
) {
on_keys.push((left_column, right_column));
} else {
return Err(DatabaseError::InvalidColumn("not found column".to_string()))?;
}
}
Ok(JoinCondition::On {
on: on_keys,
filter: None,
})
}
JoinConstraint::None => Ok(JoinCondition::None),
_ => unimplemented!("not supported join constraint {:?}", constraint),
}
}
Expand Down
Loading