From 2801bee9ceb074c89cc79e38ba029bcce216f039 Mon Sep 17 00:00:00 2001 From: Fedomn Date: Mon, 5 Dec 2022 23:23:37 +0800 Subject: [PATCH] feat(planner): new planner v2 which supports create/insert table and simple select Signed-off-by: Fedomn --- Cargo.lock | 12 + Cargo.toml | 1 + src/catalog_v2/catalog.rs | 60 +++ src/catalog_v2/catalog_set.rs | 49 +++ src/catalog_v2/constants.rs | 1 + src/catalog_v2/entry/mod.rs | 27 ++ src/catalog_v2/entry/schema_catalog_entry.rs | 38 ++ src/catalog_v2/entry/table_catalog_entry.rs | 64 +++ src/catalog_v2/errors.rs | 11 + src/catalog_v2/mod.rs | 11 + src/cli.rs | 21 +- src/execution/column_binding_resolver.rs | 39 ++ src/execution/expression_executor.rs | 33 ++ src/execution/mod.rs | 55 +++ src/execution/physical_plan/mod.rs | 30 ++ .../physical_plan/physical_create_table.rs | 18 + .../physical_plan/physical_expression_scan.rs | 29 ++ .../physical_plan/physical_insert.rs | 46 ++ .../physical_plan/physical_projection.rs | 25 ++ .../physical_plan/physical_table_scan.rs | 25 ++ src/execution/physical_plan_generator.rs | 36 ++ src/execution/util.rs | 27 ++ .../volcano_executor/create_table.rs | 41 ++ .../volcano_executor/expression_scan.rs | 36 ++ src/execution/volcano_executor/insert.rs | 62 +++ src/execution/volcano_executor/mod.rs | 61 +++ src/execution/volcano_executor/projection.rs | 31 ++ src/execution/volcano_executor/table_scan.rs | 28 ++ src/lib.rs | 6 + src/main.rs | 7 +- src/main_entry/client_context.rs | 124 ++++++ src/main_entry/db.rs | 20 + src/main_entry/errors.rs | 35 ++ src/main_entry/mod.rs | 14 + src/main_entry/pending_query_result.rs | 45 ++ src/main_entry/prepared_statement_data.rs | 18 + src/main_entry/query_context.rs | 32 ++ src/main_entry/query_result.rs | 22 + src/parser/mod.rs | 10 + src/planner_v2/binder/bind_context.rs | 102 +++++ src/planner_v2/binder/binding.rs | 55 +++ src/planner_v2/binder/errors.rs | 23 + .../expression/bind_column_ref_expression.rs | 74 ++++ 
.../expression/bind_constant_expression.rs | 28 ++ .../expression/bind_reference_expression.rs | 11 + .../binder/expression/column_binding.rs | 7 + src/planner_v2/binder/expression/mod.rs | 45 ++ src/planner_v2/binder/mod.rs | 69 +++ .../binder/query_node/bind_select_node.rs | 100 +++++ src/planner_v2/binder/query_node/mod.rs | 4 + .../binder/query_node/plan_select_node.rs | 26 ++ .../binder/statement/bind_create.rs | 45 ++ .../binder/statement/bind_insert.rs | 103 +++++ .../binder/statement/bind_select.rs | 16 + .../binder/statement/create_info.rs | 17 + src/planner_v2/binder/statement/mod.rs | 32 ++ .../binder/tableref/bind_base_table_ref.rs | 62 +++ .../tableref/bind_expression_list_ref.rs | 57 +++ src/planner_v2/binder/tableref/mod.rs | 34 ++ .../binder/tableref/plan_base_table_ref.rs | 11 + .../tableref/plan_expression_list_ref.rs | 16 + src/planner_v2/binder/util.rs | 29 ++ src/planner_v2/constants.rs | 1 + src/planner_v2/errors.rs | 11 + src/planner_v2/expression_binder.rs | 38 ++ src/planner_v2/expression_iterator.rs | 18 + src/planner_v2/logical_operator_visitor.rs | 58 +++ src/planner_v2/mod.rs | 54 +++ .../operator/logical_create_table.rs | 11 + .../operator/logical_expression_get.rs | 16 + src/planner_v2/operator/logical_get.rs | 22 + src/planner_v2/operator/logical_insert.rs | 15 + src/planner_v2/operator/logical_projection.rs | 9 + src/planner_v2/operator/mod.rs | 109 +++++ src/storage_v2/local_storage.rs | 99 +++++ src/storage_v2/mod.rs | 2 + src/types_v2/errors.rs | 9 + src/types_v2/mod.rs | 6 + src/types_v2/types.rs | 73 ++++ src/types_v2/values.rs | 401 ++++++++++++++++++ src/util/mod.rs | 20 +- tests/slt/aggregation.slt | 6 + tests/slt/alias.slt | 11 +- tests/slt/create_table.slt | 19 + tests/slt/distinct.slt | 6 + tests/slt/filter.slt | 3 + tests/slt/join.slt | 10 + tests/slt/join_filter.slt | 9 + tests/slt/limit.slt | 5 + tests/slt/order.slt | 3 + tests/slt/select.slt | 1 + tests/slt/subquery.slt | 13 + tests/sqllogictest/src/lib.rs | 27 ++ 
tests/sqllogictest/tests/sqllogictest.rs | 3 +- 94 files changed, 3298 insertions(+), 6 deletions(-) create mode 100644 src/catalog_v2/catalog.rs create mode 100644 src/catalog_v2/catalog_set.rs create mode 100644 src/catalog_v2/constants.rs create mode 100644 src/catalog_v2/entry/mod.rs create mode 100644 src/catalog_v2/entry/schema_catalog_entry.rs create mode 100644 src/catalog_v2/entry/table_catalog_entry.rs create mode 100644 src/catalog_v2/errors.rs create mode 100644 src/catalog_v2/mod.rs create mode 100644 src/execution/column_binding_resolver.rs create mode 100644 src/execution/expression_executor.rs create mode 100644 src/execution/mod.rs create mode 100644 src/execution/physical_plan/mod.rs create mode 100644 src/execution/physical_plan/physical_create_table.rs create mode 100644 src/execution/physical_plan/physical_expression_scan.rs create mode 100644 src/execution/physical_plan/physical_insert.rs create mode 100644 src/execution/physical_plan/physical_projection.rs create mode 100644 src/execution/physical_plan/physical_table_scan.rs create mode 100644 src/execution/physical_plan_generator.rs create mode 100644 src/execution/util.rs create mode 100644 src/execution/volcano_executor/create_table.rs create mode 100644 src/execution/volcano_executor/expression_scan.rs create mode 100644 src/execution/volcano_executor/insert.rs create mode 100644 src/execution/volcano_executor/mod.rs create mode 100644 src/execution/volcano_executor/projection.rs create mode 100644 src/execution/volcano_executor/table_scan.rs create mode 100644 src/main_entry/client_context.rs create mode 100644 src/main_entry/db.rs create mode 100644 src/main_entry/errors.rs create mode 100644 src/main_entry/mod.rs create mode 100644 src/main_entry/pending_query_result.rs create mode 100644 src/main_entry/prepared_statement_data.rs create mode 100644 src/main_entry/query_context.rs create mode 100644 src/main_entry/query_result.rs create mode 100644 src/planner_v2/binder/bind_context.rs 
create mode 100644 src/planner_v2/binder/binding.rs create mode 100644 src/planner_v2/binder/errors.rs create mode 100644 src/planner_v2/binder/expression/bind_column_ref_expression.rs create mode 100644 src/planner_v2/binder/expression/bind_constant_expression.rs create mode 100644 src/planner_v2/binder/expression/bind_reference_expression.rs create mode 100644 src/planner_v2/binder/expression/column_binding.rs create mode 100644 src/planner_v2/binder/expression/mod.rs create mode 100644 src/planner_v2/binder/mod.rs create mode 100644 src/planner_v2/binder/query_node/bind_select_node.rs create mode 100644 src/planner_v2/binder/query_node/mod.rs create mode 100644 src/planner_v2/binder/query_node/plan_select_node.rs create mode 100644 src/planner_v2/binder/statement/bind_create.rs create mode 100644 src/planner_v2/binder/statement/bind_insert.rs create mode 100644 src/planner_v2/binder/statement/bind_select.rs create mode 100644 src/planner_v2/binder/statement/create_info.rs create mode 100644 src/planner_v2/binder/statement/mod.rs create mode 100644 src/planner_v2/binder/tableref/bind_base_table_ref.rs create mode 100644 src/planner_v2/binder/tableref/bind_expression_list_ref.rs create mode 100644 src/planner_v2/binder/tableref/mod.rs create mode 100644 src/planner_v2/binder/tableref/plan_base_table_ref.rs create mode 100644 src/planner_v2/binder/tableref/plan_expression_list_ref.rs create mode 100644 src/planner_v2/binder/util.rs create mode 100644 src/planner_v2/constants.rs create mode 100644 src/planner_v2/errors.rs create mode 100644 src/planner_v2/expression_binder.rs create mode 100644 src/planner_v2/expression_iterator.rs create mode 100644 src/planner_v2/logical_operator_visitor.rs create mode 100644 src/planner_v2/mod.rs create mode 100644 src/planner_v2/operator/logical_create_table.rs create mode 100644 src/planner_v2/operator/logical_expression_get.rs create mode 100644 src/planner_v2/operator/logical_get.rs create mode 100644 
src/planner_v2/operator/logical_insert.rs create mode 100644 src/planner_v2/operator/logical_projection.rs create mode 100644 src/planner_v2/operator/mod.rs create mode 100644 src/storage_v2/local_storage.rs create mode 100644 src/storage_v2/mod.rs create mode 100644 src/types_v2/errors.rs create mode 100644 src/types_v2/mod.rs create mode 100644 src/types_v2/types.rs create mode 100644 src/types_v2/values.rs create mode 100644 tests/slt/create_table.slt diff --git a/Cargo.lock b/Cargo.lock index 1528a4c..9157281 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -258,6 +258,17 @@ dependencies = [ "syn", ] +[[package]] +name = "derive-new" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3418329ca0ad70234b9735dc4ceed10af4df60eff9c8e7b06cb5e520d92c3535" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "diff" version = "0.1.13" @@ -1326,6 +1337,7 @@ dependencies = [ "ahash", "anyhow", "arrow", + "derive-new", "dirs", "downcast-rs", "enum_dispatch", diff --git a/Cargo.toml b/Cargo.toml index 3d9582d..ac3a6f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ lazy_static = "1" strum = "0.24" strum_macros = "0.24" ordered-float = "3.0" +derive-new = "0.5.9" [dev-dependencies] test-case = "2" diff --git a/src/catalog_v2/catalog.rs b/src/catalog_v2/catalog.rs new file mode 100644 index 0000000..30b22c6 --- /dev/null +++ b/src/catalog_v2/catalog.rs @@ -0,0 +1,60 @@ +use std::sync::Arc; + +use super::entry::{CatalogEntry, DataTable}; +use super::{CatalogError, CatalogSet, TableCatalogEntry}; +use crate::main_entry::ClientContext; + +/// The Catalog object represents the catalog of the database. 
+#[derive(Clone, Debug, Default)] +pub struct Catalog { + /// The catalog set holding the schemas + schemas: CatalogSet, + /// The catalog version, incremented whenever anything changes in the catalog + catalog_version: usize, +} + +impl Catalog { + pub fn create_schema(&mut self, name: String) -> Result<(), CatalogError> { + self.catalog_version += 1; + let entry = CatalogEntry::default_schema_catalog_entry(self.catalog_version, name.clone()); + self.schemas.create_entry(name, entry) + } + + pub fn create_table( + client_context: Arc, + schema: String, + table: String, + data_table: DataTable, + ) -> Result<(), CatalogError> { + let mut catalog = match client_context.db.catalog.try_write() { + Ok(c) => c, + Err(_) => return Err(CatalogError::CatalogLockedError), + }; + if let CatalogEntry::SchemaCatalogEntry(mut entry) = + catalog.schemas.get_entry(schema.clone())? + { + catalog.catalog_version += 1; + entry.create_table(catalog.catalog_version, table, data_table)?; + catalog + .schemas + .replace_entry(schema, CatalogEntry::SchemaCatalogEntry(entry))?; + return Ok(()); + } + Err(CatalogError::CatalogEntryTypeNotMatch) + } + + pub fn get_table( + client_context: Arc, + schema: String, + table: String, + ) -> Result { + let catalog = match client_context.db.catalog.try_read() { + Ok(c) => c, + Err(_) => return Err(CatalogError::CatalogLockedError), + }; + if let CatalogEntry::SchemaCatalogEntry(entry) = catalog.schemas.get_entry(schema)? 
{ + return entry.get_table(table); + } + Err(CatalogError::CatalogEntryTypeNotMatch) + } +} diff --git a/src/catalog_v2/catalog_set.rs b/src/catalog_v2/catalog_set.rs new file mode 100644 index 0000000..ccac973 --- /dev/null +++ b/src/catalog_v2/catalog_set.rs @@ -0,0 +1,49 @@ +use std::collections::HashMap; + +use super::{CatalogEntry, CatalogError}; + +/// The Catalog Set stores (key, value) map of a set of CatalogEntries +#[derive(Clone, Debug, Default)] +pub struct CatalogSet { + /// The set of catalog entries, entry index to entry + entries: HashMap, + /// Mapping of string to catalog entry index + mapping: HashMap, + /// The current catalog entry index + current_entry: usize, +} + +impl CatalogSet { + pub fn create_entry(&mut self, name: String, entry: CatalogEntry) -> Result<(), CatalogError> { + if self.mapping.get(&name).is_some() { + return Err(CatalogError::CatalogEntryExists(name)); + } + self.current_entry += 1; + self.entries.insert(self.current_entry, entry); + self.mapping.insert(name, self.current_entry); + Ok(()) + } + + pub fn get_entry(&self, name: String) -> Result { + if let Some(index) = self.mapping.get(&name) { + if let Some(entry) = self.entries.get(index) { + return Ok(entry.clone()); + } + } + Err(CatalogError::CatalogEntryNotExists(name)) + } + + pub fn replace_entry( + &mut self, + name: String, + new_entry: CatalogEntry, + ) -> Result<(), CatalogError> { + if let Some(old_entry_index) = self.mapping.get(&name) { + if self.entries.get(old_entry_index).is_some() { + self.entries.insert(*old_entry_index, new_entry); + return Ok(()); + } + } + Err(CatalogError::CatalogEntryNotExists(name)) + } +} diff --git a/src/catalog_v2/constants.rs b/src/catalog_v2/constants.rs new file mode 100644 index 0000000..bb7d3bc --- /dev/null +++ b/src/catalog_v2/constants.rs @@ -0,0 +1 @@ +pub static DEFAULT_SCHEMA: &str = "main"; diff --git a/src/catalog_v2/entry/mod.rs b/src/catalog_v2/entry/mod.rs new file mode 100644 index 0000000..0233b3e --- /dev/null 
+++ b/src/catalog_v2/entry/mod.rs @@ -0,0 +1,27 @@ +mod schema_catalog_entry; +mod table_catalog_entry; + +use derive_new::new; +pub use schema_catalog_entry::*; +pub use table_catalog_entry::*; + +#[derive(Clone, Debug)] +pub enum CatalogEntry { + SchemaCatalogEntry(SchemaCatalogEntry), + TableCatalogEntry(TableCatalogEntry), +} + +impl CatalogEntry { + pub fn default_schema_catalog_entry(oid: usize, schema: String) -> Self { + Self::SchemaCatalogEntry(SchemaCatalogEntry::new(oid, schema)) + } +} + +#[allow(dead_code)] +#[derive(new, Clone, Debug)] +pub struct CatalogEntryBase { + /// The object identifier of the entry + oid: usize, + /// The name of the entry + name: String, +} diff --git a/src/catalog_v2/entry/schema_catalog_entry.rs b/src/catalog_v2/entry/schema_catalog_entry.rs new file mode 100644 index 0000000..25f3c07 --- /dev/null +++ b/src/catalog_v2/entry/schema_catalog_entry.rs @@ -0,0 +1,38 @@ +use super::table_catalog_entry::{DataTable, TableCatalogEntry}; +use super::{CatalogEntry, CatalogEntryBase}; +use crate::catalog_v2::{CatalogError, CatalogSet}; + +#[allow(dead_code)] +#[derive(Clone, Debug)] +pub struct SchemaCatalogEntry { + base: CatalogEntryBase, + tables: CatalogSet, +} + +impl SchemaCatalogEntry { + pub fn new(oid: usize, schema: String) -> Self { + Self { + base: CatalogEntryBase::new(oid, schema), + tables: CatalogSet::default(), + } + } + + pub fn create_table( + &mut self, + oid: usize, + table: String, + storage: DataTable, + ) -> Result<(), CatalogError> { + let entry = + CatalogEntry::TableCatalogEntry(TableCatalogEntry::new(oid, table.clone(), storage)); + self.tables.create_entry(table, entry)?; + Ok(()) + } + + pub fn get_table(&self, table: String) -> Result { + match self.tables.get_entry(table.clone())? 
{ + CatalogEntry::TableCatalogEntry(e) => Ok(e), + _ => Err(CatalogError::CatalogEntryNotExists(table)), + } + } +} diff --git a/src/catalog_v2/entry/table_catalog_entry.rs b/src/catalog_v2/entry/table_catalog_entry.rs new file mode 100644 index 0000000..097b547 --- /dev/null +++ b/src/catalog_v2/entry/table_catalog_entry.rs @@ -0,0 +1,64 @@ +use std::collections::HashMap; + +use derive_new::new; + +use super::CatalogEntryBase; +use crate::types_v2::LogicalType; + +#[allow(dead_code)] +#[derive(Clone, Debug)] +pub struct TableCatalogEntry { + pub(crate) base: CatalogEntryBase, + pub(crate) storage: DataTable, + /// A list of columns that are part of this table + pub(crate) columns: Vec, + /// A map of column name to column index + pub(crate) name_map: HashMap, +} + +impl TableCatalogEntry { + pub fn new(oid: usize, table: String, storage: DataTable) -> Self { + let mut name_map = HashMap::new(); + let mut columns = vec![]; + storage + .column_definitions + .iter() + .enumerate() + .for_each(|(idx, col)| { + columns.push(col.clone()); + name_map.insert(col.name.clone(), idx); + }); + Self { + base: CatalogEntryBase::new(oid, table), + storage, + columns, + name_map, + } + } +} + +/// DataTable represents a physical table on disk +#[derive(new, Clone, Debug, PartialEq, Eq, Hash)] +pub struct DataTable { + /// The table info + pub(crate) info: DataTableInfo, + /// The set of physical columns stored by this DataTable + pub(crate) column_definitions: Vec, +} + +#[derive(new, Clone, Debug, PartialEq, Eq, Hash)] +pub struct DataTableInfo { + /// schema of the table + pub(crate) schema: String, + /// name of the table + pub(crate) table: String, +} + +/// A column of a table +#[derive(new, Clone, Debug, PartialEq, Eq, Hash)] +pub struct ColumnDefinition { + /// The name of the entry + pub(crate) name: String, + /// The type of the column + pub(crate) ty: LogicalType, +} diff --git a/src/catalog_v2/errors.rs b/src/catalog_v2/errors.rs new file mode 100644 index 
0000000..e6e7758 --- /dev/null +++ b/src/catalog_v2/errors.rs @@ -0,0 +1,11 @@ +#[derive(thiserror::Error, Debug)] +pub enum CatalogError { + #[error("CatalogEntry: {0} already exists")] + CatalogEntryExists(String), + #[error("CatalogEntry: {0} not exists")] + CatalogEntryNotExists(String), + #[error("CatalogEntry type not match")] + CatalogEntryTypeNotMatch, + #[error("Catalog locked, please retry")] + CatalogLockedError, +} diff --git a/src/catalog_v2/mod.rs b/src/catalog_v2/mod.rs new file mode 100644 index 0000000..825fb86 --- /dev/null +++ b/src/catalog_v2/mod.rs @@ -0,0 +1,11 @@ +mod catalog; +mod catalog_set; +mod constants; +mod entry; +mod errors; + +pub use catalog::*; +pub use catalog_set::*; +pub use constants::*; +pub use entry::*; +pub use errors::*; diff --git a/src/cli.rs b/src/cli.rs index 5816e62..bd8eb0a 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,16 +1,20 @@ use std::fs::File; +use std::sync::Arc; use anyhow::{Error, Result}; use rustyline::error::ReadlineError; use rustyline::Editor; +use crate::main_entry::ClientContext; use crate::util::pretty_batches; use crate::Database; -pub async fn interactive(db: Database) -> Result<()> { +pub async fn interactive(db: Database, client_context: Arc) -> Result<()> { let mut rl = Editor::<()>::new()?; load_history(&mut rl); + let mut enable_v2 = false; + loop { let read_sql = read_sql(&mut rl); match read_sql { @@ -19,7 +23,20 @@ pub async fn interactive(db: Database) -> Result<()> { rl.add_history_entry(sql.as_str()); let start_time = std::time::Instant::now(); - run_sql(&db, sql).await?; + if sql.starts_with("enable_v2") { + enable_v2 = true; + println!("---- enable sqlrs v2 ! 
----"); + continue; + } + + if enable_v2 { + match client_context.query(sql).await { + Ok(_) => {} + Err(err) => println!("Run Error: {}", err), + } + } else { + run_sql(&db, sql).await?; + } let end_time = std::time::Instant::now(); let time_consumed = end_time.duration_since(start_time); diff --git a/src/execution/column_binding_resolver.rs b/src/execution/column_binding_resolver.rs new file mode 100644 index 0000000..8a00909 --- /dev/null +++ b/src/execution/column_binding_resolver.rs @@ -0,0 +1,39 @@ +use crate::planner_v2::{ + BoundColumnRefExpression, BoundExpression, BoundExpressionBase, BoundReferenceExpression, + ColumnBinding, LogicalOperator, LogicalOperatorVisitor, +}; + +#[derive(Default)] +pub struct ColumnBindingResolver { + bindings: Vec, +} + +impl LogicalOperatorVisitor for ColumnBindingResolver { + fn visit_operator(&mut self, op: &mut LogicalOperator) { + { + self.visit_operator_children(op); + self.visit_operator_expressions(op); + self.bindings = op.get_column_bindings(); + } + } + + fn visit_replace_column_ref(&self, expr: &BoundColumnRefExpression) -> Option { + assert!(expr.depth == 0); + // check the current set of column bindings to see which index corresponds to the column + // reference + if let Some(idx) = self.bindings.iter().position(|e| expr.binding == *e) { + let expr = BoundReferenceExpression::new( + BoundExpressionBase::new(expr.base.alias.clone(), expr.base.return_type.clone()), + idx, + ); + return Some(BoundExpression::BoundReferenceExpression(expr)); + } + + // could not bind the column reference, this should never happen and indicates a bug in the + // code generate an error message + panic!( + "Failed to bind column reference {} [{}.{}] (bindings: {:?}), ", + expr.base.alias, expr.binding.table_idx, expr.binding.column_idx, self.bindings + ); + } +} diff --git a/src/execution/expression_executor.rs b/src/execution/expression_executor.rs new file mode 100644 index 0000000..c648b0f --- /dev/null +++ 
b/src/execution/expression_executor.rs @@ -0,0 +1,33 @@ +use arrow::array::ArrayRef; +use arrow::record_batch::RecordBatch; + +use super::ExecutorError; +use crate::planner_v2::BoundExpression; + +/// ExpressionExecutor is responsible for executing a set of expressions and storing the result in a +/// data chunk +pub struct ExpressionExecutor; + +impl ExpressionExecutor { + pub fn execute( + expressions: &[BoundExpression], + input: &RecordBatch, + ) -> Result, ExecutorError> { + let mut result = vec![]; + for expr in expressions.iter() { + result.push(Self::execute_internal(expr, input)?); + } + Ok(result) + } + + fn execute_internal( + expr: &BoundExpression, + input: &RecordBatch, + ) -> Result { + Ok(match expr { + BoundExpression::BoundColumnRefExpression(_) => todo!(), + BoundExpression::BoundConstantExpression(e) => e.value.to_array(), + BoundExpression::BoundReferenceExpression(e) => input.column(e.index).clone(), + }) + } +} diff --git a/src/execution/mod.rs b/src/execution/mod.rs new file mode 100644 index 0000000..f464967 --- /dev/null +++ b/src/execution/mod.rs @@ -0,0 +1,55 @@ +mod column_binding_resolver; +mod expression_executor; +mod physical_plan; +mod physical_plan_generator; +mod volcano_executor; +use std::sync::Arc; +mod util; + +use arrow::error::ArrowError; +pub use column_binding_resolver::*; +use derive_new::new; +pub use expression_executor::*; +pub use physical_plan::*; +pub use physical_plan_generator::*; +pub use util::*; +pub use volcano_executor::*; + +use crate::catalog_v2::CatalogError; +use crate::main_entry::ClientContext; +use crate::types_v2::TypeError; + +#[derive(new)] +pub struct ExecutionContext { + pub(crate) client_context: Arc, +} + +impl ExecutionContext { + pub fn clone_client_context(&self) -> Arc { + self.client_context.clone() + } +} + +#[derive(thiserror::Error, Debug)] +pub enum ExecutorError { + #[error("catalog error: {0}")] + CatalogError( + #[source] + #[from] + CatalogError, + ), + #[error("arrow error: {0}")] 
+ ArrowError( + #[source] + #[from] + ArrowError, + ), + #[error("type error: {0}")] + TypeError( + #[source] + #[from] + TypeError, + ), + #[error("Executor internal error: {0}")] + InternalError(String), +} diff --git a/src/execution/physical_plan/mod.rs b/src/execution/physical_plan/mod.rs new file mode 100644 index 0000000..226c234 --- /dev/null +++ b/src/execution/physical_plan/mod.rs @@ -0,0 +1,30 @@ +mod physical_create_table; +mod physical_expression_scan; +mod physical_insert; +mod physical_projection; +mod physical_table_scan; + +use derive_new::new; +pub use physical_create_table::*; +pub use physical_expression_scan::*; +pub use physical_insert::*; +pub use physical_projection::*; +pub use physical_table_scan::*; + +use crate::types_v2::LogicalType; + +#[derive(new, Default, Clone)] +pub struct PhysicalOperatorBase { + pub(crate) children: Vec, + /// The types returned by this physical operator + pub(crate) _types: Vec, +} + +#[derive(Clone)] +pub enum PhysicalOperator { + PhysicalCreateTable(PhysicalCreateTable), + PhysicalExpressionScan(PhysicalExpressionScan), + PhysicalInsert(PhysicalInsert), + PhysicalTableScan(PhysicalTableScan), + PhysicalProjection(PhysicalProjection), +} diff --git a/src/execution/physical_plan/physical_create_table.rs b/src/execution/physical_plan/physical_create_table.rs new file mode 100644 index 0000000..68a6699 --- /dev/null +++ b/src/execution/physical_plan/physical_create_table.rs @@ -0,0 +1,18 @@ +use derive_new::new; + +use super::{PhysicalOperator, PhysicalOperatorBase}; +use crate::execution::PhysicalPlanGenerator; +use crate::planner_v2::{BoundCreateTableInfo, LogicalCreateTable}; + +#[derive(new, Clone)] +pub struct PhysicalCreateTable { + #[new(default)] + _base: PhysicalOperatorBase, + pub(crate) info: BoundCreateTableInfo, +} + +impl PhysicalPlanGenerator { + pub(crate) fn create_physical_create_table(&self, op: LogicalCreateTable) -> PhysicalOperator { + 
PhysicalOperator::PhysicalCreateTable(PhysicalCreateTable::new(op.info)) + } +} diff --git a/src/execution/physical_plan/physical_expression_scan.rs b/src/execution/physical_plan/physical_expression_scan.rs new file mode 100644 index 0000000..625389c --- /dev/null +++ b/src/execution/physical_plan/physical_expression_scan.rs @@ -0,0 +1,29 @@ +use derive_new::new; + +use super::{PhysicalOperator, PhysicalOperatorBase}; +use crate::execution::PhysicalPlanGenerator; +use crate::planner_v2::{BoundExpression, LogicalExpressionGet}; +use crate::types_v2::LogicalType; + +/// The PhysicalExpressionScan scans a set of expressions +#[derive(new, Clone)] +pub struct PhysicalExpressionScan { + #[new(default)] + pub(crate) _base: PhysicalOperatorBase, + /// The types of the expressions + pub(crate) expr_types: Vec, + /// The set of expressions to scan + pub(crate) expressions: Vec>, +} + +impl PhysicalPlanGenerator { + pub(crate) fn create_physical_expression_scan( + &self, + op: LogicalExpressionGet, + ) -> PhysicalOperator { + PhysicalOperator::PhysicalExpressionScan(PhysicalExpressionScan::new( + op.expr_types, + op.expressions, + )) + } +} diff --git a/src/execution/physical_plan/physical_insert.rs b/src/execution/physical_plan/physical_insert.rs new file mode 100644 index 0000000..5cd130c --- /dev/null +++ b/src/execution/physical_plan/physical_insert.rs @@ -0,0 +1,46 @@ +use derive_new::new; + +use super::{PhysicalOperator, PhysicalOperatorBase}; +use crate::catalog_v2::TableCatalogEntry; +use crate::execution::PhysicalPlanGenerator; +use crate::planner_v2::LogicalInsert; +use crate::types_v2::LogicalType; + +#[derive(new, Clone)] +pub struct PhysicalInsert { + pub(crate) base: PhysicalOperatorBase, + /// The insertion map ([table_index -> index in result, or INVALID_INDEX if not specified]) + pub(crate) column_index_list: Vec, + /// The expected types for the INSERT statement + pub(crate) expected_types: Vec, + pub(crate) table: TableCatalogEntry, +} + +impl 
PhysicalInsert { + pub fn clone_with_base(&self, base: PhysicalOperatorBase) -> Self { + Self { + base, + column_index_list: self.column_index_list.clone(), + expected_types: self.expected_types.clone(), + table: self.table.clone(), + } + } +} + +impl PhysicalPlanGenerator { + pub(crate) fn create_physical_insert(&self, op: LogicalInsert) -> PhysicalOperator { + let new_children = op + .base + .children + .into_iter() + .map(|op| self.create_plan_internal(op)) + .collect::>(); + let base = PhysicalOperatorBase::new(new_children, op.base.types); + PhysicalOperator::PhysicalInsert(PhysicalInsert::new( + base, + op.column_index_list, + op.expected_types, + op.table, + )) + } +} diff --git a/src/execution/physical_plan/physical_projection.rs b/src/execution/physical_plan/physical_projection.rs new file mode 100644 index 0000000..133e802 --- /dev/null +++ b/src/execution/physical_plan/physical_projection.rs @@ -0,0 +1,25 @@ +use derive_new::new; + +use super::{PhysicalOperator, PhysicalOperatorBase}; +use crate::execution::PhysicalPlanGenerator; +use crate::planner_v2::{BoundExpression, LogicalProjection}; + +#[derive(new, Clone)] +pub struct PhysicalProjection { + pub(crate) base: PhysicalOperatorBase, + pub(crate) select_list: Vec, +} + +impl PhysicalPlanGenerator { + pub(crate) fn create_physical_projection(&self, op: LogicalProjection) -> PhysicalOperator { + let new_children = op + .base + .children + .into_iter() + .map(|p| self.create_plan_internal(p)) + .collect::>(); + let types = op.base.types; + let base = PhysicalOperatorBase::new(new_children, types); + PhysicalOperator::PhysicalProjection(PhysicalProjection::new(base, op.base.expressioins)) + } +} diff --git a/src/execution/physical_plan/physical_table_scan.rs b/src/execution/physical_plan/physical_table_scan.rs new file mode 100644 index 0000000..ee0112b --- /dev/null +++ b/src/execution/physical_plan/physical_table_scan.rs @@ -0,0 +1,25 @@ +use derive_new::new; + +use super::{PhysicalOperator, 
PhysicalOperatorBase}; +use crate::catalog_v2::TableCatalogEntry; +use crate::execution::PhysicalPlanGenerator; +use crate::planner_v2::LogicalGet; +use crate::types_v2::LogicalType; + +#[derive(new, Clone)] +pub struct PhysicalTableScan { + pub(crate) _base: PhysicalOperatorBase, + pub(crate) bind_table: TableCatalogEntry, + /// The types of ALL columns that can be returned by the table function + pub(crate) returned_types: Vec, + /// The names of ALL columns that can be returned by the table function + pub(crate) names: Vec, +} + +impl PhysicalPlanGenerator { + pub(crate) fn create_physical_table_scan(&self, op: LogicalGet) -> PhysicalOperator { + let base = PhysicalOperatorBase::new(vec![], op.base.types); + let plan = PhysicalTableScan::new(base, op.bind_table, op.returned_types, op.names); + PhysicalOperator::PhysicalTableScan(plan) + } +} diff --git a/src/execution/physical_plan_generator.rs b/src/execution/physical_plan_generator.rs new file mode 100644 index 0000000..de102f0 --- /dev/null +++ b/src/execution/physical_plan_generator.rs @@ -0,0 +1,36 @@ +use std::sync::Arc; + +use derive_new::new; + +use super::{ColumnBindingResolver, PhysicalOperator}; +use crate::main_entry::ClientContext; +use crate::planner_v2::{LogicalOperator, LogicalOperatorVisitor}; + +#[derive(new)] +pub struct PhysicalPlanGenerator { + pub(crate) _client_context: Arc, +} + +impl PhysicalPlanGenerator { + pub(crate) fn create_plan(&self, mut op: LogicalOperator) -> PhysicalOperator { + // first resolve column references + let mut resolver = ColumnBindingResolver::default(); + resolver.visit_operator(&mut op); + + // now resolve types of all the operators + op.resolve_operator_types(); + + // then create the main physical plan + self.create_plan_internal(op) + } + + pub(crate) fn create_plan_internal(&self, op: LogicalOperator) -> PhysicalOperator { + match op { + LogicalOperator::LogicalCreateTable(op) => self.create_physical_create_table(op), + 
LogicalOperator::LogicalExpressionGet(op) => self.create_physical_expression_scan(op), + LogicalOperator::LogicalInsert(op) => self.create_physical_insert(op), + LogicalOperator::LogicalGet(op) => self.create_physical_table_scan(op), + LogicalOperator::LogicalProjection(op) => self.create_physical_projection(op), + } + } +} diff --git a/src/execution/util.rs b/src/execution/util.rs new file mode 100644 index 0000000..64d144c --- /dev/null +++ b/src/execution/util.rs @@ -0,0 +1,27 @@ +use std::collections::HashMap; + +use arrow::datatypes::{Field, Schema, SchemaRef}; + +use crate::planner_v2::BoundExpression; +use crate::types_v2::LogicalType; + +pub struct SchemaUtil; + +impl SchemaUtil { + pub fn new_schema_ref(names: &[String], types: &[LogicalType]) -> SchemaRef { + let fields = names + .iter() + .zip(types.iter()) + .map(|(name, ty)| Field::new(name, ty.clone().into(), true)) + .collect::>(); + SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new())) + } + + pub fn new_schema_ref_from_exprs(exprs: &[BoundExpression]) -> SchemaRef { + let fields = exprs + .iter() + .map(|e| Field::new(&e.alias(), e.return_type().into(), true)) + .collect::>(); + SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new())) + } +} diff --git a/src/execution/volcano_executor/create_table.rs b/src/execution/volcano_executor/create_table.rs new file mode 100644 index 0000000..8350ecb --- /dev/null +++ b/src/execution/volcano_executor/create_table.rs @@ -0,0 +1,41 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::array::StringArray; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use derive_new::new; +use futures_async_stream::try_stream; + +use crate::catalog_v2::{Catalog, DataTable, DataTableInfo}; +use crate::execution::{ExecutionContext, ExecutorError, PhysicalCreateTable}; + +#[derive(new)] +pub struct CreateTable { + pub(crate) plan: PhysicalCreateTable, +} + +impl CreateTable { + 
#[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] + pub async fn execute(self, context: Arc) { + let schema = self.plan.info.base.base.schema; + let table = self.plan.info.base.table; + let column_definitions = self.plan.info.base.columns; + let data_table = DataTable::new( + DataTableInfo::new(schema.clone(), table.clone()), + column_definitions, + ); + Catalog::create_table( + context.clone_client_context(), + schema, + table.clone(), + data_table, + )?; + let array = Arc::new(StringArray::from(vec![format!("CREATE TABLE {}", table)])); + let fields = vec![Field::new("success", DataType::Utf8, false)]; + yield RecordBatch::try_new( + SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new())), + vec![array], + )?; + } +} diff --git a/src/execution/volcano_executor/expression_scan.rs b/src/execution/volcano_executor/expression_scan.rs new file mode 100644 index 0000000..9dfeb95 --- /dev/null +++ b/src/execution/volcano_executor/expression_scan.rs @@ -0,0 +1,36 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use derive_new::new; +use futures_async_stream::try_stream; + +use crate::execution::{ + ExecutionContext, ExecutorError, ExpressionExecutor, PhysicalExpressionScan, +}; + +#[derive(new)] +pub struct ExpressionScan { + pub(crate) plan: PhysicalExpressionScan, +} + +impl ExpressionScan { + #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] + pub async fn execute(self, _context: Arc) { + let mut fields = vec![]; + for (idx, ty) in self.plan.expr_types.iter().enumerate() { + fields.push(Field::new( + format!("col{}", idx).as_str(), + ty.clone().into(), + true, + )); + } + let schema = SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new())); + let input = RecordBatch::new_empty(schema.clone()); + for exprs in self.plan.expressions.iter() { + let columns = ExpressionExecutor::execute(exprs, &input)?; + yield 
RecordBatch::try_new(schema.clone(), columns)?; + } + } +} diff --git a/src/execution/volcano_executor/insert.rs b/src/execution/volcano_executor/insert.rs new file mode 100644 index 0000000..ed44c3f --- /dev/null +++ b/src/execution/volcano_executor/insert.rs @@ -0,0 +1,62 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use derive_new::new; +use futures_async_stream::try_stream; + +use crate::execution::{ + BoxedExecutor, ExecutionContext, ExecutorError, ExpressionExecutor, PhysicalInsert, +}; +use crate::planner_v2::{ + BoundConstantExpression, BoundExpression, BoundExpressionBase, BoundReferenceExpression, + INVALID_INDEX, +}; +use crate::storage_v2::LocalStorage; +use crate::types_v2::ScalarValue; + +#[derive(new)] +pub struct Insert { + pub(crate) plan: PhysicalInsert, + pub(crate) child: BoxedExecutor, +} + +impl Insert { + #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] + pub async fn execute(self, context: Arc) { + let table = self.plan.table.storage; + let mut exprs = vec![]; + let mut fields = vec![]; + for (table_col_idx, col_insert_idx) in self.plan.column_index_list.iter().enumerate() { + let column = table.column_definitions[table_col_idx].clone(); + fields.push(Field::new( + column.name.as_str(), + column.ty.clone().into(), + true, + )); + let ty = column.ty.clone(); + let base = BoundExpressionBase::new("".to_string(), ty.clone()); + if *col_insert_idx == INVALID_INDEX { + let value = ScalarValue::new_none_value(&ty.into())?; + let expr = BoundExpression::BoundConstantExpression(BoundConstantExpression::new( + base, value, + )); + exprs.push(expr); + } else { + let expr = BoundExpression::BoundReferenceExpression( + BoundReferenceExpression::new(base, *col_insert_idx), + ); + exprs.push(expr); + } + } + let schema = SchemaRef::new(Schema::new_with_metadata(fields.clone(), HashMap::new())); + #[for_await] + for batch in self.child { + 
let batch = batch?; + let columns = ExpressionExecutor::execute(&exprs, &batch)?; + let chunk = RecordBatch::try_new(schema.clone(), columns)?; + LocalStorage::append(context.clone_client_context(), &table, chunk); + } + } +} diff --git a/src/execution/volcano_executor/mod.rs b/src/execution/volcano_executor/mod.rs new file mode 100644 index 0000000..3b0b2d0 --- /dev/null +++ b/src/execution/volcano_executor/mod.rs @@ -0,0 +1,61 @@ +mod create_table; +mod expression_scan; +mod insert; +mod projection; +mod table_scan; +use std::sync::Arc; + +use arrow::record_batch::RecordBatch; +pub use create_table::*; +pub use expression_scan::*; +use futures::stream::BoxStream; +use futures::TryStreamExt; +pub use insert::*; +pub use projection::*; +pub use table_scan::*; + +use super::{ExecutionContext, ExecutorError, PhysicalOperator}; + +pub type BoxedExecutor = BoxStream<'static, Result>; + +#[derive(Default)] +pub struct VolcanoExecutor {} + +impl VolcanoExecutor { + pub fn new() -> Self { + VolcanoExecutor::default() + } + + fn build(&self, plan: PhysicalOperator, context: Arc) -> BoxedExecutor { + match plan { + PhysicalOperator::PhysicalCreateTable(op) => CreateTable::new(op).execute(context), + PhysicalOperator::PhysicalExpressionScan(op) => { + ExpressionScan::new(op).execute(context) + } + PhysicalOperator::PhysicalInsert(op) => { + let child = op.base.children.first().unwrap().clone(); + let child_executor = self.build(child, context.clone()); + Insert::new(op, child_executor).execute(context) + } + PhysicalOperator::PhysicalTableScan(op) => TableScan::new(op).execute(context), + PhysicalOperator::PhysicalProjection(op) => { + let child = op.base.children.first().unwrap().clone(); + let child_executor = self.build(child, context.clone()); + Projection::new(op, child_executor).execute(context) + } + } + } + + pub(crate) async fn try_execute( + &self, + plan: PhysicalOperator, + context: Arc, + ) -> Result, ExecutorError> { + let mut output = Vec::new(); + let mut 
volcano_executor = self.build(plan, context.clone()); + while let Some(batch) = volcano_executor.try_next().await? { + output.push(batch); + } + Ok(output) + } +} diff --git a/src/execution/volcano_executor/projection.rs b/src/execution/volcano_executor/projection.rs new file mode 100644 index 0000000..ac32b8e --- /dev/null +++ b/src/execution/volcano_executor/projection.rs @@ -0,0 +1,31 @@ +use std::sync::Arc; + +use arrow::record_batch::RecordBatch; +use derive_new::new; +use futures_async_stream::try_stream; + +use crate::execution::{ + BoxedExecutor, ExecutionContext, ExecutorError, ExpressionExecutor, PhysicalProjection, + SchemaUtil, +}; + +#[derive(new)] +pub struct Projection { + pub(crate) plan: PhysicalProjection, + pub(crate) child: BoxedExecutor, +} + +impl Projection { + #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] + pub async fn execute(self, _context: Arc) { + let exprs = self.plan.select_list; + let schema = SchemaUtil::new_schema_ref_from_exprs(&exprs); + + #[for_await] + for batch in self.child { + let batch = batch?; + let columns = ExpressionExecutor::execute(&exprs, &batch)?; + yield RecordBatch::try_new(schema.clone(), columns)?; + } + } +} diff --git a/src/execution/volcano_executor/table_scan.rs b/src/execution/volcano_executor/table_scan.rs new file mode 100644 index 0000000..b67f2a2 --- /dev/null +++ b/src/execution/volcano_executor/table_scan.rs @@ -0,0 +1,28 @@ +use std::sync::Arc; + +use arrow::record_batch::RecordBatch; +use derive_new::new; +use futures_async_stream::try_stream; + +use crate::execution::{ExecutionContext, ExecutorError, PhysicalTableScan, SchemaUtil}; +use crate::storage_v2::LocalStorage; + +#[derive(new)] +pub struct TableScan { + pub(crate) plan: PhysicalTableScan, +} + +impl TableScan { + #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] + pub async fn execute(self, context: Arc) { + let schema = SchemaUtil::new_schema_ref(&self.plan.names, &self.plan.returned_types); + + let table = 
self.plan.bind_table; + let mut local_storage_reader = LocalStorage::create_reader(&table.storage); + let client_context = context.clone_client_context(); + while let Some(batch) = local_storage_reader.next_batch(client_context.clone()) { + let columns = batch.columns().to_vec(); + yield RecordBatch::try_new(schema.clone(), columns)? + } + } +} diff --git a/src/lib.rs b/src/lib.rs index c6e0c26..7cc7ca3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,14 +9,20 @@ extern crate lazy_static; pub mod binder; pub mod catalog; +pub mod catalog_v2; pub mod cli; pub mod db; +pub mod execution; pub mod executor; +pub mod main_entry; pub mod optimizer; pub mod parser; pub mod planner; +pub mod planner_v2; pub mod storage; +pub mod storage_v2; pub mod types; +pub mod types_v2; pub mod util; pub use self::db::{Database, DatabaseError}; diff --git a/src/main.rs b/src/main.rs index 3da8670..a347095 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,7 @@ +use std::sync::Arc; + use anyhow::Result; +use sqlrs::main_entry::{ClientContext, DatabaseInstance}; use sqlrs::{cli, Database}; #[tokio::main] @@ -10,7 +13,9 @@ async fn main() -> Result<()> { create_csv_table(&db, "t1")?; create_csv_table(&db, "t2")?; - cli::interactive(db).await?; + let dbv2 = Arc::new(DatabaseInstance::default()); + let client_context = ClientContext::new(dbv2); + cli::interactive(db, client_context).await?; Ok(()) } diff --git a/src/main_entry/client_context.rs b/src/main_entry/client_context.rs new file mode 100644 index 0000000..e7c91a0 --- /dev/null +++ b/src/main_entry/client_context.rs @@ -0,0 +1,124 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use arrow::record_batch::RecordBatch; +use futures::lock::Mutex; +use sqlparser::ast::Statement; + +use super::query_context::ActiveQueryContext; +use super::{ + DatabaseError, DatabaseInstance, PendingQueryResult, PreparedStatementData, QueryResult, +}; +use crate::execution::{PhysicalPlanGenerator, VolcanoExecutor}; +use 
crate::parser::Sqlparser; +use crate::planner_v2::Planner; +use crate::util::pretty_batches_with; + +/// The ClientContext holds information relevant to the current client session during execution +pub struct ClientContext { + /// The database that this client is connected to + pub(crate) db: Arc, + pub(crate) active_query: Mutex, + pub(crate) interrupted: AtomicBool, +} + +impl ClientContext { + pub fn new(db: Arc) -> Arc { + Arc::new(Self { + db, + active_query: Mutex::new(ActiveQueryContext::default()), + interrupted: AtomicBool::new(false), + }) + } + + pub async fn query(self: &Arc, sql: String) -> Result, DatabaseError> { + let statements = Sqlparser::parse(sql.clone())?; + if statements.is_empty() { + return Err(DatabaseError::InternalError( + "invalid statement".to_string(), + )); + } + + let mut collection_result = vec![]; + + for stat in statements.iter() { + let result = self.pending_query(stat).await?; + match result { + QueryResult::MaterializedQueryResult(res) => { + pretty_batches_with(&res.collection, &res.base.names, &res.base.types); + collection_result.extend(res.collection); + } + } + } + + Ok(collection_result) + } + + async fn pending_query( + self: &Arc, + statement: &Statement, + ) -> Result { + let pending_query = self + .pending_statement_or_prepared_statement(statement) + .await?; + let result = pending_query.execute().await?; + Ok(result) + } + + async fn pending_statement_or_prepared_statement( + self: &Arc, + statement: &Statement, + ) -> Result, DatabaseError> { + self.initial_cleanup().await; + + self.active_query.lock().await.query = Some(statement.to_string()); + // prepare the query for execution + let prepared = self.create_prepared_statement(statement).await?; + self.active_query.lock().await.prepared = Some(prepared); + // set volcano executor + let executor = VolcanoExecutor::new(); + self.active_query.lock().await.executor = Some(executor); + // return pending query result + let pending_query_result = 
Arc::new(PendingQueryResult::new(self.clone())); + self.active_query.lock().await.open_result = Some(pending_query_result.clone()); + Ok(pending_query_result.clone()) + } + + async fn create_prepared_statement( + self: &Arc, + statement: &Statement, + ) -> Result { + let mut planner = Planner::new(self.clone()); + planner.create_plan(statement)?; + let logical_plan = planner.plan.unwrap(); + let names = planner.names.unwrap(); + let types = planner.types.unwrap(); + + let physical_planner = PhysicalPlanGenerator::new(self.clone()); + let physical_plan = physical_planner.create_plan(logical_plan); + + let result = PreparedStatementData::new(statement.clone(), physical_plan, names, types); + Ok(result) + } + + async fn initial_cleanup(self: &Arc) { + self.cleanup_internal().await; + self.interrupted.store(false, Ordering::Release); + } + + async fn cleanup_internal(self: &Arc) { + self.active_query.lock().await.reset(); + } + + pub async fn is_active_request(self: &Arc, query_result: &PendingQueryResult) -> bool { + let active_query_context = self.active_query.lock().await; + if active_query_context.is_empty() { + return false; + } + if let Some(open_result) = &active_query_context.open_result { + std::ptr::eq(open_result.as_ref(), query_result) + } else { + false + } + } +} diff --git a/src/main_entry/db.rs b/src/main_entry/db.rs new file mode 100644 index 0000000..daaad42 --- /dev/null +++ b/src/main_entry/db.rs @@ -0,0 +1,20 @@ +use std::sync::{Arc, RwLock}; + +use crate::catalog_v2::{Catalog, DEFAULT_SCHEMA}; +use crate::storage_v2::LocalStorage; + +pub struct DatabaseInstance { + pub(crate) storage: RwLock, + pub(crate) catalog: Arc>, +} + +impl Default for DatabaseInstance { + fn default() -> Self { + let mut catalog = Catalog::default(); + catalog.create_schema(DEFAULT_SCHEMA.to_string()).unwrap(); + Self { + storage: RwLock::new(LocalStorage::default()), + catalog: Arc::new(RwLock::new(catalog)), + } + } +} diff --git a/src/main_entry/errors.rs 
b/src/main_entry/errors.rs new file mode 100644 index 0000000..fa71b4b --- /dev/null +++ b/src/main_entry/errors.rs @@ -0,0 +1,35 @@ +use arrow::error::ArrowError; +use sqlparser::parser::ParserError; + +use crate::execution::ExecutorError; +use crate::planner_v2::PlannerError; + +#[derive(thiserror::Error, Debug)] +pub enum DatabaseError { + #[error("parse error: {0}")] + ParserError( + #[source] + #[from] + ParserError, + ), + #[error("planner error: {0}")] + PlannerError( + #[source] + #[from] + PlannerError, + ), + #[error("executor error: {0}")] + ExecutorError( + #[source] + #[from] + ExecutorError, + ), + #[error("Arrow error: {0}")] + ArrowError( + #[source] + #[from] + ArrowError, + ), + #[error("Internal error: {0}")] + InternalError(String), +} diff --git a/src/main_entry/mod.rs b/src/main_entry/mod.rs new file mode 100644 index 0000000..babadd2 --- /dev/null +++ b/src/main_entry/mod.rs @@ -0,0 +1,14 @@ +mod client_context; +mod db; +mod errors; +mod pending_query_result; +mod prepared_statement_data; +mod query_context; +mod query_result; + +pub use client_context::*; +pub use db::*; +pub use errors::*; +pub use pending_query_result::*; +pub use prepared_statement_data::*; +pub use query_result::*; diff --git a/src/main_entry/pending_query_result.rs b/src/main_entry/pending_query_result.rs new file mode 100644 index 0000000..d237627 --- /dev/null +++ b/src/main_entry/pending_query_result.rs @@ -0,0 +1,45 @@ +use std::sync::Arc; + +use derive_new::new; + +use super::{BaseQueryResult, ClientContext, DatabaseError, MaterializedQueryResult, QueryResult}; +use crate::execution::ExecutionContext; + +#[derive(new)] +pub struct PendingQueryResult { + pub(crate) client_context: Arc, +} + +impl PendingQueryResult { + pub async fn execute(&self) -> Result { + self.check_executable_internal().await?; + + let mut active_query_context = self.client_context.active_query.lock().await; + let executor = active_query_context.executor.take().unwrap(); + let prepared = 
active_query_context.prepared.take().unwrap(); + // execute the query + let execution_context = Arc::new(ExecutionContext::new(self.client_context.clone())); + let collection = executor + .try_execute(prepared.plan, execution_context) + .await?; + // set query result + let materialized_query_result = MaterializedQueryResult::new( + BaseQueryResult::new(prepared.types, prepared.names), + collection, + ); + Ok(QueryResult::MaterializedQueryResult( + materialized_query_result, + )) + } + + async fn check_executable_internal(&self) -> Result<(), DatabaseError> { + // whether the current pending query is active or not + let invalidated = !self.client_context.is_active_request(self).await; + if invalidated { + return Err(DatabaseError::InternalError( + "Attempting to execute an unsuccessful or closed pending query result".to_string(), + )); + } + Ok(()) + } +} diff --git a/src/main_entry/prepared_statement_data.rs b/src/main_entry/prepared_statement_data.rs new file mode 100644 index 0000000..04f1d16 --- /dev/null +++ b/src/main_entry/prepared_statement_data.rs @@ -0,0 +1,18 @@ +use derive_new::new; +use sqlparser::ast::Statement; + +use crate::execution::PhysicalOperator; +use crate::types_v2::LogicalType; + +#[derive(new)] +#[allow(dead_code)] +pub struct PreparedStatementData { + /// The unbound SQL statement that was prepared + pub(crate) unbound_statement: Statement, + /// The fully prepared physical plan of the prepared statement + pub(crate) plan: PhysicalOperator, + /// The result names + pub(crate) names: Vec, + /// The result types + pub(crate) types: Vec, +} diff --git a/src/main_entry/query_context.rs b/src/main_entry/query_context.rs new file mode 100644 index 0000000..d06f6f9 --- /dev/null +++ b/src/main_entry/query_context.rs @@ -0,0 +1,32 @@ +use std::sync::Arc; + +use super::{PendingQueryResult, PreparedStatementData}; +use crate::execution::VolcanoExecutor; + +#[derive(Default)] +pub struct ActiveQueryContext { + /// The query that is currently being 
executed + pub(crate) query: Option, + /// The currently open result + pub(crate) open_result: Option>, + /// Prepared statement data + pub(crate) prepared: Option, + /// The query executor + pub(crate) executor: Option, +} + +impl ActiveQueryContext { + pub fn reset(&mut self) { + self.query = None; + self.open_result = None; + self.prepared = None; + self.executor = None; + } + + pub fn is_empty(&self) -> bool { + self.query.is_none() + && self.open_result.is_none() + && self.prepared.is_none() + && self.executor.is_none() + } +} diff --git a/src/main_entry/query_result.rs b/src/main_entry/query_result.rs new file mode 100644 index 0000000..4531270 --- /dev/null +++ b/src/main_entry/query_result.rs @@ -0,0 +1,22 @@ +use arrow::record_batch::RecordBatch; +use derive_new::new; + +use crate::types_v2::LogicalType; + +#[derive(new, Debug)] +pub struct BaseQueryResult { + /// The SQL types of the result + pub(crate) types: Vec, + /// The names of the result + pub(crate) names: Vec, +} + +#[derive(new)] +pub struct MaterializedQueryResult { + pub(crate) base: BaseQueryResult, + pub(crate) collection: Vec, +} + +pub enum QueryResult { + MaterializedQueryResult(MaterializedQueryResult), +} diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 0e76d74..814566d 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -10,3 +10,13 @@ pub fn parse(sql: &str) -> Result, ParserError> { } Ok(stmts) } + +pub struct Sqlparser {} + +impl Sqlparser { + pub fn parse(sql: String) -> Result, ParserError> { + let dialect = PostgreSqlDialect {}; + let stmts = Parser::parse_sql(&dialect, sql.as_str())?; + Ok(stmts) + } +} diff --git a/src/planner_v2/binder/bind_context.rs b/src/planner_v2/binder/bind_context.rs new file mode 100644 index 0000000..4c269b0 --- /dev/null +++ b/src/planner_v2/binder/bind_context.rs @@ -0,0 +1,102 @@ +use std::collections::HashMap; + +use derive_new::new; + +use super::{BindError, Binding, BoundColumnRefExpression}; +use crate::catalog_v2::CatalogEntry; 
+use crate::types_v2::LogicalType; + +/// The BindContext object keeps track of all the tables and columns +/// that are encountered during the binding process. +#[derive(new, Debug, Clone)] +pub struct BindContext { + /// table name -> table binding + #[new(default)] + pub(crate) bindings: HashMap, + #[new(default)] + pub(crate) binding_list: Vec, +} + +impl BindContext { + pub fn add_binding( + &mut self, + alias: String, + index: usize, + types: Vec, + names: Vec, + catalog_entry: Option, + ) { + let name_map = names + .iter() + .enumerate() + .map(|(i, name)| (name.clone(), i)) + .collect(); + let mut binding = Binding::new(alias.clone(), index, types, names, name_map); + binding.catalog_entry = catalog_entry; + self.bindings.insert(alias, binding.clone()); + self.binding_list.push(binding); + } + + pub fn add_generic_binding( + &mut self, + alias: String, + index: usize, + types: Vec, + names: Vec, + ) { + self.add_binding(alias, index, types, names, None); + } + + pub fn add_base_table( + &mut self, + alias: String, + index: usize, + types: Vec, + names: Vec, + catalog_entry: CatalogEntry, + ) { + self.add_binding(alias, index, types, names, Some(catalog_entry)); + } + + pub fn get_binding(&self, table_name: &str) -> Option { + self.bindings.get(table_name).cloned() + } + + pub fn get_matching_binding(&self, column_name: &str) -> Result { + let mut mathing_table_name = None; + for binding in self.binding_list.iter() { + if binding.has_match_binding(column_name) { + if mathing_table_name.is_some() { + return Err(BindError::Internal(format!( + "Ambiguous column name {}", + column_name + ))); + } + mathing_table_name = Some(binding.alias.clone()); + } + } + if let Some(table_name) = mathing_table_name { + Ok(table_name) + } else { + Err(BindError::Internal(format!( + "Column {} not found in any table", + column_name + ))) + } + } + + pub fn bind_column( + &mut self, + table_name: &str, + column_name: &str, + ) -> Result { + if let Some(table_binding) = 
self.get_binding(table_name) { + table_binding.bind_column(column_name, 0) + } else { + Err(BindError::Internal(format!( + "Table {} not found in context", + table_name + ))) + } + } +} diff --git a/src/planner_v2/binder/binding.rs b/src/planner_v2/binder/binding.rs new file mode 100644 index 0000000..07b27f0 --- /dev/null +++ b/src/planner_v2/binder/binding.rs @@ -0,0 +1,55 @@ +use std::collections::HashMap; + +use derive_new::new; + +use super::{BindError, BoundColumnRefExpression, BoundExpressionBase, ColumnBinding}; +use crate::catalog_v2::CatalogEntry; +use crate::types_v2::LogicalType; + +/// A Binding represents a binding to a table, table-producing function +/// or subquery with a specified table index. +#[derive(new, Clone, Debug)] +pub struct Binding { + /// The alias of the binding + pub(crate) alias: String, + /// The table index of the binding + pub(crate) index: usize, + pub(crate) types: Vec, + #[allow(dead_code)] + pub(crate) names: Vec, + /// Name -> index for the names + pub(crate) name_map: HashMap, + /// The underlying catalog entry (if any) + #[new(default)] + pub(crate) catalog_entry: Option, +} +impl Binding { + pub fn has_match_binding(&self, column_name: &str) -> bool { + self.try_get_binding_index(column_name).is_some() + } + + pub fn try_get_binding_index(&self, column_name: &str) -> Option { + self.name_map.get(column_name).cloned() + } + + pub fn bind_column( + &self, + column_name: &str, + depth: usize, + ) -> Result { + if let Some(col_idx) = self.try_get_binding_index(column_name) { + let col_type = self.types[col_idx].clone(); + let col_binding = ColumnBinding::new(self.index, col_idx); + Ok(BoundColumnRefExpression::new( + BoundExpressionBase::new(column_name.to_string(), col_type), + col_binding, + depth, + )) + } else { + Err(BindError::Internal(format!( + "Column {} not found in table {}", + column_name, self.alias + ))) + } + } +} diff --git a/src/planner_v2/binder/errors.rs b/src/planner_v2/binder/errors.rs new file mode 
100644 index 0000000..83075bc --- /dev/null +++ b/src/planner_v2/binder/errors.rs @@ -0,0 +1,23 @@ +#[derive(thiserror::Error, Debug)] +pub enum BindError { + #[error("unsupported expr: {0}")] + UnsupportedExpr(String), + #[error("unsupported statement: {0}")] + UnsupportedStmt(String), + #[error("sqlparser unsupported statement: {0}")] + SqlParserUnsupportedStmt(String), + #[error("bind internal error: {0}")] + Internal(String), + #[error("type error: {0}")] + TypeError( + #[from] + #[source] + crate::types_v2::TypeError, + ), + #[error("catalog error: {0}")] + CatalogError( + #[from] + #[source] + crate::catalog_v2::CatalogError, + ), +} diff --git a/src/planner_v2/binder/expression/bind_column_ref_expression.rs b/src/planner_v2/binder/expression/bind_column_ref_expression.rs new file mode 100644 index 0000000..b4c79c1 --- /dev/null +++ b/src/planner_v2/binder/expression/bind_column_ref_expression.rs @@ -0,0 +1,74 @@ +use derive_new::new; +use itertools::Itertools; + +use super::{BoundExpression, BoundExpressionBase, ColumnBinding}; +use crate::planner_v2::{BindError, ExpressionBinder}; +use crate::types_v2::LogicalType; + +/// A BoundColumnRef expression represents a ColumnRef expression that was bound to an actual table +/// and column index. It is not yet executable, however. The ColumnBindingResolver transforms the +/// BoundColumnRefExpressions into BoundReferenceExpressions, which refer to indexes into the +/// physical chunks that pass through the executor. +#[derive(new, Debug, Clone)] +pub struct BoundColumnRefExpression { + pub(crate) base: BoundExpressionBase, + /// Column index set by the binder, used to generate the final BoundReferenceExpression + pub(crate) binding: ColumnBinding, + /// The subquery depth (i.e. depth 0 = current query, depth 1 = parent query, depth 2 = parent + /// of parent, etc...). This is only non-zero for correlated expressions inside subqueries. 
+ pub(crate) depth: usize, +} + +impl ExpressionBinder<'_> { + /// qualify column name with existing table name + fn qualify_column_name( + &self, + table_name: Option<&String>, + column_name: &String, + ) -> Result<(String, String), BindError> { + if let Some(table_name) = table_name { + Ok((table_name.to_string(), column_name.to_string())) + } else { + let table_name = self.binder.bind_context.get_matching_binding(column_name)?; + Ok((table_name, column_name.to_string())) + } + } + + pub fn bind_column_ref_expr( + &mut self, + idents: &[sqlparser::ast::Ident], + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + let idents = idents + .iter() + .map(|ident| ident.value.to_lowercase()) + .collect_vec(); + + let (_schema_name, table_name, column_name) = match idents.as_slice() { + [column] => (None, None, column), + [table, column] => (None, Some(table), column), + [schema, table, column] => (Some(schema), Some(table), column), + _ => return Err(BindError::UnsupportedExpr(format!("{:?}", idents))), + }; + + let (table_name, column_name) = self.qualify_column_name(table_name, column_name)?; + + // check table_name, and column_name + if self.binder.has_match_binding(&table_name, &column_name) { + let bound_col_ref = self + .binder + .bind_context + .bind_column(&table_name, &column_name)?; + result_names.push(bound_col_ref.base.alias.clone()); + result_types.push(bound_col_ref.base.return_type.clone()); + Ok(BoundExpression::BoundColumnRefExpression(bound_col_ref)) + } else { + println!("current binder context: {:#?}", self.binder.bind_context); + Err(BindError::Internal(format!( + "column not found: {}", + column_name + ))) + } + } +} diff --git a/src/planner_v2/binder/expression/bind_constant_expression.rs b/src/planner_v2/binder/expression/bind_constant_expression.rs new file mode 100644 index 0000000..c0d2d77 --- /dev/null +++ b/src/planner_v2/binder/expression/bind_constant_expression.rs @@ -0,0 +1,28 @@ +use derive_new::new; + +use 
super::{BoundExpression, BoundExpressionBase}; +use crate::planner_v2::{BindError, ExpressionBinder}; +use crate::types_v2::{LogicalType, ScalarValue}; + +#[derive(new, Debug, Clone)] +pub struct BoundConstantExpression { + pub(crate) base: BoundExpressionBase, + pub(crate) value: ScalarValue, +} + +impl ExpressionBinder<'_> { + pub fn bind_constant_expr( + &self, + v: &sqlparser::ast::Value, + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + let scalar: ScalarValue = v.into(); + let base = BoundExpressionBase::new("".to_string(), scalar.get_logical_type()); + result_names.push(base.alias.clone()); + result_types.push(base.return_type.clone()); + let expr = + BoundExpression::BoundConstantExpression(BoundConstantExpression::new(base, scalar)); + Ok(expr) + } +} diff --git a/src/planner_v2/binder/expression/bind_reference_expression.rs b/src/planner_v2/binder/expression/bind_reference_expression.rs new file mode 100644 index 0000000..9f93502 --- /dev/null +++ b/src/planner_v2/binder/expression/bind_reference_expression.rs @@ -0,0 +1,11 @@ +use derive_new::new; + +use super::BoundExpressionBase; + +/// A BoundReferenceExpression represents a physical index into a DataChunk +#[derive(new, Debug, Clone)] +pub struct BoundReferenceExpression { + pub(crate) base: BoundExpressionBase, + /// Index used to access data in the chunks + pub(crate) index: usize, +} diff --git a/src/planner_v2/binder/expression/column_binding.rs b/src/planner_v2/binder/expression/column_binding.rs new file mode 100644 index 0000000..811b86d --- /dev/null +++ b/src/planner_v2/binder/expression/column_binding.rs @@ -0,0 +1,7 @@ +use derive_new::new; + +#[derive(new, Clone, Debug, PartialEq, Eq, Hash)] +pub struct ColumnBinding { + pub(crate) table_idx: usize, + pub(crate) column_idx: usize, +} diff --git a/src/planner_v2/binder/expression/mod.rs b/src/planner_v2/binder/expression/mod.rs new file mode 100644 index 0000000..19d61b0 --- /dev/null +++ 
b/src/planner_v2/binder/expression/mod.rs @@ -0,0 +1,45 @@ +mod bind_column_ref_expression; +mod bind_constant_expression; +mod bind_reference_expression; +mod column_binding; + +pub use bind_column_ref_expression::*; +pub use bind_constant_expression::*; +pub use bind_reference_expression::*; +pub use column_binding::*; +use derive_new::new; + +use crate::types_v2::LogicalType; + +/// The Expression represents a bound Expression with a return type +#[derive(new, Debug, Clone)] +pub struct BoundExpressionBase { + /// The alias of the expression, + pub(crate) alias: String, + pub(crate) return_type: LogicalType, +} + +#[derive(Debug, Clone)] +pub enum BoundExpression { + BoundColumnRefExpression(BoundColumnRefExpression), + BoundConstantExpression(BoundConstantExpression), + BoundReferenceExpression(BoundReferenceExpression), +} + +impl BoundExpression { + pub fn return_type(&self) -> LogicalType { + match self { + BoundExpression::BoundColumnRefExpression(expr) => expr.base.return_type.clone(), + BoundExpression::BoundConstantExpression(expr) => expr.base.return_type.clone(), + BoundExpression::BoundReferenceExpression(expr) => expr.base.return_type.clone(), + } + } + + pub fn alias(&self) -> String { + match self { + BoundExpression::BoundColumnRefExpression(expr) => expr.base.alias.clone(), + BoundExpression::BoundConstantExpression(expr) => expr.base.alias.clone(), + BoundExpression::BoundReferenceExpression(expr) => expr.base.alias.clone(), + } + } +} diff --git a/src/planner_v2/binder/mod.rs b/src/planner_v2/binder/mod.rs new file mode 100644 index 0000000..cdd1f99 --- /dev/null +++ b/src/planner_v2/binder/mod.rs @@ -0,0 +1,69 @@ +mod bind_context; +mod binding; +mod errors; +mod expression; +mod query_node; +mod statement; +mod tableref; +mod util; + +use std::sync::Arc; + +pub use bind_context::*; +pub use binding::*; +pub use errors::*; +pub use expression::*; +pub use query_node::*; +pub use statement::*; +pub use tableref::*; +pub use util::*; + +use 
crate::main_entry::ClientContext; + +#[derive(Clone)] +pub struct Binder { + client_context: Arc, + bind_context: BindContext, + /// The count of bound_tables + bound_tables: usize, + #[allow(dead_code)] + parent: Option>, +} + +impl Binder { + pub fn new(client_context: Arc) -> Self { + Self { + client_context, + bind_context: BindContext::new(), + bound_tables: 0, + parent: None, + } + } + + pub fn new_with_parent(client_context: Arc, parent: Arc) -> Self { + Self { + client_context, + bind_context: BindContext::new(), + bound_tables: 0, + parent: Some(parent), + } + } + + pub fn clone_client_context(&self) -> Arc { + self.client_context.clone() + } + + pub fn generate_table_index(&mut self) -> usize { + self.bound_tables += 1; + self.bound_tables + } + + pub fn has_match_binding(&mut self, table_name: &str, column_name: &str) -> bool { + let binding = self.bind_context.get_binding(table_name); + if binding.is_none() { + return false; + } + let binding = binding.unwrap(); + binding.has_match_binding(column_name) + } +} diff --git a/src/planner_v2/binder/query_node/bind_select_node.rs b/src/planner_v2/binder/query_node/bind_select_node.rs new file mode 100644 index 0000000..9060536 --- /dev/null +++ b/src/planner_v2/binder/query_node/bind_select_node.rs @@ -0,0 +1,100 @@ +use derive_new::new; +use sqlparser::ast::{Ident, Query}; + +use crate::planner_v2::{ + BindError, Binder, BoundExpression, BoundTableRef, ExpressionBinder, VALUES_LIST_ALIAS, +}; +use crate::types_v2::LogicalType; + +#[derive(new, Debug)] +pub struct BoundSelectNode { + /// The names returned by this QueryNode. + pub(crate) names: Vec, + /// The types returned by this QueryNode. 
+ pub(crate) types: Vec, + /// The projection list + pub(crate) select_list: Vec, + /// The FROM clause + pub(crate) from_table: BoundTableRef, + /// Index used by the LogicalProjection + #[new(default)] + pub(crate) projection_index: usize, +} + +impl Binder { + pub fn bind_select_node(&mut self, select_node: &Query) -> Result { + let projection_index = self.generate_table_index(); + let mut bound_select_node = match &*select_node.body { + sqlparser::ast::SetExpr::Select(select) => self.bind_select_body(select)?, + sqlparser::ast::SetExpr::Query(_) => todo!(), + sqlparser::ast::SetExpr::SetOperation { .. } => todo!(), + sqlparser::ast::SetExpr::Values(v) => self.bind_values(v)?, + sqlparser::ast::SetExpr::Insert(_) => todo!(), + }; + bound_select_node.projection_index = projection_index; + Ok(bound_select_node) + } + + pub fn bind_values( + &mut self, + values: &sqlparser::ast::Values, + ) -> Result { + let bound_expression_list_ref = self.bind_expression_list_ref(values)?; + let names = bound_expression_list_ref.names.clone(); + let types = bound_expression_list_ref.types.clone(); + let mut expr_binder = ExpressionBinder::new(self); + let select_list = names + .iter() + .map(|n| { + let idents = vec![ + Ident::new(VALUES_LIST_ALIAS.to_string()), + Ident::new(n.to_string()), + ]; + expr_binder.bind_column_ref_expr(&idents, &mut vec![], &mut vec![]) + }) + .try_collect::>()?; + + let bound_table_ref = BoundTableRef::BoundExpressionListRef(bound_expression_list_ref); + let node = BoundSelectNode::new(names, types, select_list, bound_table_ref); + Ok(node) + } + + pub fn bind_select_body( + &mut self, + select: &sqlparser::ast::Select, + ) -> Result { + let from_table = self.bind_table_ref(select.from.as_slice())?; + + let mut result_names = vec![]; + let mut result_types = vec![]; + let select_list = select + .projection + .iter() + .map(|item| self.bind_select_item(item, &mut result_names, &mut result_types)) + .collect::, _>>()?; + + Ok(BoundSelectNode::new( + 
result_names, + result_types, + select_list, + from_table, + )) + } + + fn bind_select_item( + &mut self, + item: &sqlparser::ast::SelectItem, + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + let mut expr_binder = ExpressionBinder::new(self); + match item { + sqlparser::ast::SelectItem::UnnamedExpr(expr) => { + expr_binder.bind_expression(expr, result_names, result_types) + } + sqlparser::ast::SelectItem::ExprWithAlias { .. } => todo!(), + sqlparser::ast::SelectItem::Wildcard => todo!(), + sqlparser::ast::SelectItem::QualifiedWildcard(_) => todo!(), + } + } +} diff --git a/src/planner_v2/binder/query_node/mod.rs b/src/planner_v2/binder/query_node/mod.rs new file mode 100644 index 0000000..9a196c2 --- /dev/null +++ b/src/planner_v2/binder/query_node/mod.rs @@ -0,0 +1,4 @@ +mod bind_select_node; +mod plan_select_node; +pub use bind_select_node::*; +pub use plan_select_node::*; diff --git a/src/planner_v2/binder/query_node/plan_select_node.rs b/src/planner_v2/binder/query_node/plan_select_node.rs new file mode 100644 index 0000000..ab803c3 --- /dev/null +++ b/src/planner_v2/binder/query_node/plan_select_node.rs @@ -0,0 +1,26 @@ +use super::BoundSelectNode; +use crate::planner_v2::BoundTableRef::{BoundBaseTableRef, BoundExpressionListRef}; +use crate::planner_v2::{ + BindError, Binder, BoundStatement, LogicalOperator, LogicalOperatorBase, LogicalProjection, +}; + +impl Binder { + pub fn create_plan_for_select_node( + &mut self, + node: BoundSelectNode, + ) -> Result { + let root = match node.from_table { + BoundExpressionListRef(bound_ref) => { + self.create_plan_for_expression_list_ref(bound_ref)? 
+ } + BoundBaseTableRef(bound_ref) => self.create_plan_for_base_tabel_ref(*bound_ref)?, + }; + + let root = LogicalOperator::LogicalProjection(LogicalProjection::new( + LogicalOperatorBase::new(vec![root], node.select_list, node.types.clone()), + node.projection_index, + )); + + Ok(BoundStatement::new(root, node.types, node.names)) + } +} diff --git a/src/planner_v2/binder/statement/bind_create.rs b/src/planner_v2/binder/statement/bind_create.rs new file mode 100644 index 0000000..d066bc8 --- /dev/null +++ b/src/planner_v2/binder/statement/bind_create.rs @@ -0,0 +1,45 @@ +use sqlparser::ast::Statement; + +use super::{BoundStatement, CreateTableInfo}; +use crate::catalog_v2::ColumnDefinition; +use crate::planner_v2::{ + BindError, Binder, CreateInfoBase, LogicalCreateTable, LogicalOperator, SqlparserResolver, +}; +use crate::types_v2::LogicalType; + +impl Binder { + pub fn bind_create_table(&self, stmt: &Statement) -> Result { + match stmt { + Statement::CreateTable { name, columns, .. } => { + let (schema, table) = SqlparserResolver::object_name_to_schema_table(name)?; + let column_definitions = columns + .iter() + .map(SqlparserResolver::column_def_to_column_definition) + .try_collect()?; + let bound_info = BoundCreateTableInfo::new(schema, table, column_definitions); + let plan = LogicalOperator::LogicalCreateTable(LogicalCreateTable::new(bound_info)); + Ok(BoundStatement::new( + plan, + vec![LogicalType::Varchar], + vec!["success".to_string()], + )) + } + _ => Err(BindError::UnsupportedStmt(format!("{:?}", stmt))), + } + } +} + +#[derive(Debug, Clone)] +pub struct BoundCreateTableInfo { + pub(crate) base: CreateTableInfo, +} + +impl BoundCreateTableInfo { + pub fn new(schema: String, table: String, column_definitions: Vec) -> Self { + let base = CreateInfoBase::new(schema); + let create_table_info = CreateTableInfo::new(base, table, column_definitions); + Self { + base: create_table_info, + } + } +} diff --git a/src/planner_v2/binder/statement/bind_insert.rs 
b/src/planner_v2/binder/statement/bind_insert.rs new file mode 100644 index 0000000..854a972 --- /dev/null +++ b/src/planner_v2/binder/statement/bind_insert.rs @@ -0,0 +1,103 @@ +use std::collections::HashMap; + +use sqlparser::ast::Statement; + +use super::BoundStatement; +use crate::catalog_v2::Catalog; +use crate::planner_v2::{ + BindError, Binder, BoundTableRef, LogicalInsert, LogicalOperator, LogicalOperatorBase, + SqlparserResolver, INVALID_INDEX, +}; +use crate::types_v2::LogicalType; + +impl Binder { + pub fn bind_insert(&mut self, stmt: &Statement) -> Result { + match stmt { + Statement::Insert { + table_name, + columns, + source, + .. + } => { + let (schema_name, table_name) = + SqlparserResolver::object_name_to_schema_table(table_name)?; + let table = Catalog::get_table( + self.clone_client_context(), + schema_name, + table_name.clone(), + )?; + + // insert column mapped to table column type + let mut expected_types = vec![]; + // insert column mapped to table column index + let mut named_column_indices = vec![]; + // The insertion map ([table_index -> index in result, or DConstants::INVALID_INDEX + // if not specified]) + let mut column_index_list = vec![]; + if columns.is_empty() { + for (idx, col) in table.columns.iter().enumerate() { + named_column_indices.push(idx); + column_index_list.push(idx); + expected_types.push(col.ty.clone()); + } + } else { + // insertion statement specifies column list + // column_name to insert columns index + let mut column_name_2_insert_idx_map = HashMap::new(); + for (idx, col) in columns.iter().enumerate() { + column_name_2_insert_idx_map.insert(col.value.clone(), idx); + let column_index = match table.name_map.get(col.value.as_str()) { + Some(e) => e, + None => { + return Err(BindError::Internal(format!( + "column {} not found in table {}", + col.value, table_name + ))) + } + }; + expected_types.push(table.columns[*column_index].ty.clone()); + named_column_indices.push(*column_index); + } + for col in 
table.columns.iter() { + let insert_column_index = + match column_name_2_insert_idx_map.get(col.name.as_str()) { + Some(i) => *i, + None => INVALID_INDEX, + }; + column_index_list.push(insert_column_index); + } + } + + let select_node = self.bind_select_node(source)?; + let expected_columns_cnt = named_column_indices.len(); + if let BoundTableRef::BoundExpressionListRef(table_ref) = &select_node.from_table { + // CheckInsertColumnCountMismatch + let insert_columns_cnt = table_ref.values.first().unwrap().len(); + if expected_columns_cnt != insert_columns_cnt { + return Err(BindError::Internal(format!( + "insert column count mismatch, expected: {}, actual: {}", + expected_columns_cnt, insert_columns_cnt + ))); + } + }; + + // TODO: cast types + + let select_node = self.create_plan_for_select_node(select_node)?; + let plan = select_node.plan; + let root = LogicalInsert::new( + LogicalOperatorBase::new(vec![plan], vec![], vec![]), + column_index_list, + expected_types, + table, + ); + Ok(BoundStatement::new( + LogicalOperator::LogicalInsert(root), + vec![LogicalType::Varchar], + vec!["success".to_string()], + )) + } + _ => Err(BindError::UnsupportedStmt(format!("{:?}", stmt))), + } + } +} diff --git a/src/planner_v2/binder/statement/bind_select.rs b/src/planner_v2/binder/statement/bind_select.rs new file mode 100644 index 0000000..0d53ddd --- /dev/null +++ b/src/planner_v2/binder/statement/bind_select.rs @@ -0,0 +1,16 @@ +use sqlparser::ast::Statement; + +use super::BoundStatement; +use crate::planner_v2::{BindError, Binder}; + +impl Binder { + pub fn bind_select(&mut self, stmt: &Statement) -> Result { + match stmt { + Statement::Query(query) => { + let node = self.bind_select_node(query)?; + self.create_plan_for_select_node(node) + } + _ => Err(BindError::UnsupportedStmt(format!("{:?}", stmt))), + } + } +} diff --git a/src/planner_v2/binder/statement/create_info.rs b/src/planner_v2/binder/statement/create_info.rs new file mode 100644 index 0000000..ebb85d5 --- 
/dev/null +++ b/src/planner_v2/binder/statement/create_info.rs @@ -0,0 +1,17 @@ +use derive_new::new; + +use crate::catalog_v2::ColumnDefinition; + +#[derive(new, Debug, Clone)] +pub struct CreateTableInfo { + pub(crate) base: CreateInfoBase, + /// Table name to insert to + pub(crate) table: String, + /// List of columns of the table + pub(crate) columns: Vec, +} + +#[derive(new, Debug, Clone)] +pub struct CreateInfoBase { + pub(crate) schema: String, +} diff --git a/src/planner_v2/binder/statement/mod.rs b/src/planner_v2/binder/statement/mod.rs new file mode 100644 index 0000000..9d759dc --- /dev/null +++ b/src/planner_v2/binder/statement/mod.rs @@ -0,0 +1,32 @@ +mod bind_create; +mod bind_insert; +mod bind_select; +mod create_info; +pub use bind_create::*; +pub use bind_insert::*; +pub use bind_select::*; +pub use create_info::*; +use derive_new::new; +use sqlparser::ast::Statement; + +use super::{BindError, Binder}; +use crate::planner_v2::LogicalOperator; +use crate::types_v2::LogicalType; + +#[derive(new, Debug)] +pub struct BoundStatement { + pub(crate) plan: LogicalOperator, + pub(crate) types: Vec, + pub(crate) names: Vec, +} + +impl Binder { + pub fn bind(&mut self, statement: &Statement) -> Result { + match statement { + Statement::CreateTable { .. } => self.bind_create_table(statement), + Statement::Insert { .. } => self.bind_insert(statement), + Statement::Query { .. 
} => self.bind_select(statement), + _ => Err(BindError::UnsupportedStmt(format!("{:?}", statement))), + } + } +} diff --git a/src/planner_v2/binder/tableref/bind_base_table_ref.rs b/src/planner_v2/binder/tableref/bind_base_table_ref.rs new file mode 100644 index 0000000..40882c9 --- /dev/null +++ b/src/planner_v2/binder/tableref/bind_base_table_ref.rs @@ -0,0 +1,62 @@ +use derive_new::new; + +use super::BoundTableRef; +use crate::catalog_v2::{Catalog, CatalogEntry, TableCatalogEntry}; +use crate::planner_v2::{ + BindError, Binder, LogicalGet, LogicalOperator, LogicalOperatorBase, SqlparserResolver, +}; + +/// Represents a TableReference to a base table in the schema +#[derive(new, Debug)] +pub struct BoundBaseTableRef { + #[allow(dead_code)] + pub(crate) table: TableCatalogEntry, + pub(crate) get: LogicalOperator, +} + +impl Binder { + pub fn bind_base_table_ref( + &mut self, + table: sqlparser::ast::TableFactor, + ) -> Result { + match table { + sqlparser::ast::TableFactor::Table { name, alias, .. 
} => { + let table_index = self.generate_table_index(); + let (schema, table) = SqlparserResolver::object_name_to_schema_table(&name)?; + let alias = alias + .map(|a| a.to_string()) + .unwrap_or_else(|| table.clone()); + let table = Catalog::get_table(self.clone_client_context(), schema, table)?; + + let mut return_names = vec![]; + let mut return_types = vec![]; + for col in table.columns.iter() { + return_names.push(col.name.clone()); + return_types.push(col.ty.clone()); + } + let logical_get = LogicalGet::new( + LogicalOperatorBase::default(), + table_index, + table.clone(), + return_types.clone(), + return_names.clone(), + ); + let get = LogicalOperator::LogicalGet(logical_get); + self.bind_context.add_base_table( + alias, + table_index, + return_types, + return_names, + CatalogEntry::TableCatalogEntry(table.clone()), + ); + let bound_table_ref = + BoundTableRef::BoundBaseTableRef(Box::new(BoundBaseTableRef::new(table, get))); + Ok(bound_table_ref) + } + other => Err(BindError::Internal(format!( + "unexpected table type: {}, only bind TableFactor::Table", + other + ))), + } + } +} diff --git a/src/planner_v2/binder/tableref/bind_expression_list_ref.rs b/src/planner_v2/binder/tableref/bind_expression_list_ref.rs new file mode 100644 index 0000000..ee89bd4 --- /dev/null +++ b/src/planner_v2/binder/tableref/bind_expression_list_ref.rs @@ -0,0 +1,57 @@ +use derive_new::new; +use sqlparser::ast::Values; + +use crate::planner_v2::{BindError, Binder, BoundExpression, ExpressionBinder}; +use crate::types_v2::LogicalType; + +pub static VALUES_LIST_ALIAS: &str = "valueslist"; + +/// Represents a TableReference to an in-line list of to-be-evaluated expressions (a VALUES clause) +#[derive(new, Debug)] +pub struct BoundExpressionListRef { + /// The bound VALUES list + pub(crate) values: Vec>, + /// The generated names of the values list + pub(crate) names: Vec, + /// The types of the values list + pub(crate) types: Vec, + /// The index in the bind context + pub(crate) bind_index: usize, +} + +impl Binder { + 
pub fn bind_expression_list_ref( + &mut self, + values: &Values, + ) -> Result { + let mut bound_expr_list = vec![]; + let mut names = vec![]; + let mut types = vec![]; + let mut finish_name = false; + + let mut expr_binder = ExpressionBinder::new(self); + + for val_expr_list in values.0.iter() { + let mut bound_expr_row = vec![]; + for (idx, expr) in val_expr_list.iter().enumerate() { + let bound_expr = expr_binder.bind_expression(expr, &mut vec![], &mut vec![])?; + if !finish_name { + names.push(format!("col{}", idx)); + types.push(bound_expr.return_type()); + } + bound_expr_row.push(bound_expr); + } + bound_expr_list.push(bound_expr_row); + finish_name = true; + } + let table_index = self.generate_table_index(); + self.bind_context.add_generic_binding( + VALUES_LIST_ALIAS.to_string(), + table_index, + types.clone(), + names.clone(), + ); + let bound_ref = BoundExpressionListRef::new(bound_expr_list, names, types, table_index); + Ok(bound_ref) + } +} diff --git a/src/planner_v2/binder/tableref/mod.rs b/src/planner_v2/binder/tableref/mod.rs new file mode 100644 index 0000000..620acd9 --- /dev/null +++ b/src/planner_v2/binder/tableref/mod.rs @@ -0,0 +1,34 @@ +mod bind_base_table_ref; +mod bind_expression_list_ref; +mod plan_base_table_ref; +mod plan_expression_list_ref; +pub use bind_base_table_ref::*; +pub use bind_expression_list_ref::*; +pub use plan_base_table_ref::*; +pub use plan_expression_list_ref::*; + +use super::{BindError, Binder}; + +#[derive(Debug)] +pub enum BoundTableRef { + BoundExpressionListRef(BoundExpressionListRef), + BoundBaseTableRef(Box), +} + +impl Binder { + pub fn bind_table_ref( + &mut self, + table_refs: &[sqlparser::ast::TableWithJoins], + ) -> Result { + let first_table = table_refs[0].clone(); + match first_table.relation.clone() { + sqlparser::ast::TableFactor::Table { .. 
} => { + self.bind_base_table_ref(first_table.relation) + } + other => Err(BindError::Internal(format!( + "unexpected table type: {}", + other + ))), + } + } +} diff --git a/src/planner_v2/binder/tableref/plan_base_table_ref.rs b/src/planner_v2/binder/tableref/plan_base_table_ref.rs new file mode 100644 index 0000000..f506f0b --- /dev/null +++ b/src/planner_v2/binder/tableref/plan_base_table_ref.rs @@ -0,0 +1,11 @@ +use super::BoundBaseTableRef; +use crate::planner_v2::{BindError, Binder, LogicalOperator}; + +impl Binder { + pub fn create_plan_for_base_tabel_ref( + &mut self, + bound_ref: BoundBaseTableRef, + ) -> Result { + Ok(bound_ref.get) + } +} diff --git a/src/planner_v2/binder/tableref/plan_expression_list_ref.rs b/src/planner_v2/binder/tableref/plan_expression_list_ref.rs new file mode 100644 index 0000000..d3cb934 --- /dev/null +++ b/src/planner_v2/binder/tableref/plan_expression_list_ref.rs @@ -0,0 +1,16 @@ +use super::BoundExpressionListRef; +use crate::planner_v2::{ + BindError, Binder, LogicalExpressionGet, LogicalOperator, LogicalOperatorBase, +}; + +impl Binder { + pub fn create_plan_for_expression_list_ref( + &mut self, + bound_ref: BoundExpressionListRef, + ) -> Result { + let table_idx = bound_ref.bind_index; + let base = LogicalOperatorBase::default(); + let plan = LogicalExpressionGet::new(base, table_idx, bound_ref.types, bound_ref.values); + Ok(LogicalOperator::LogicalExpressionGet(plan)) + } +} diff --git a/src/planner_v2/binder/util.rs b/src/planner_v2/binder/util.rs new file mode 100644 index 0000000..5e59223 --- /dev/null +++ b/src/planner_v2/binder/util.rs @@ -0,0 +1,29 @@ +use sqlparser::ast::{ColumnDef, ObjectName}; + +use super::BindError; +use crate::catalog_v2::{ColumnDefinition, DEFAULT_SCHEMA}; + +pub struct SqlparserResolver {} + +impl SqlparserResolver { + /// Resolve object_name which is a name of a table, view, custom type, etc., possibly + /// multi-part, i.e. 
db.schema.obj + pub fn object_name_to_schema_table( + object_name: &ObjectName, + ) -> Result<(String, String), BindError> { + let (schema, table) = match object_name.0.as_slice() { + [table] => (DEFAULT_SCHEMA.to_string(), table.value.clone()), + [schema, table] => (schema.value.clone(), table.value.clone()), + _ => return Err(BindError::SqlParserUnsupportedStmt(object_name.to_string())), + }; + Ok((schema, table)) + } + + pub fn column_def_to_column_definition( + column_def: &ColumnDef, + ) -> Result { + let name = column_def.name.value.clone(); + let ty = column_def.data_type.clone().try_into()?; + Ok(ColumnDefinition::new(name, ty)) + } +} diff --git a/src/planner_v2/constants.rs b/src/planner_v2/constants.rs new file mode 100644 index 0000000..2036696 --- /dev/null +++ b/src/planner_v2/constants.rs @@ -0,0 +1 @@ +pub static INVALID_INDEX: usize = std::usize::MAX; diff --git a/src/planner_v2/errors.rs b/src/planner_v2/errors.rs new file mode 100644 index 0000000..f7feab1 --- /dev/null +++ b/src/planner_v2/errors.rs @@ -0,0 +1,11 @@ +use super::BindError; + +#[derive(thiserror::Error, Debug)] +pub enum PlannerError { + #[error("bind error: {0}")] + BindError( + #[from] + #[source] + BindError, + ), +} diff --git a/src/planner_v2/expression_binder.rs b/src/planner_v2/expression_binder.rs new file mode 100644 index 0000000..4fef4e8 --- /dev/null +++ b/src/planner_v2/expression_binder.rs @@ -0,0 +1,38 @@ +use std::slice; + +use derive_new::new; + +use super::{BindError, Binder, BoundExpression}; +use crate::types_v2::LogicalType; + +#[derive(new)] +pub struct ExpressionBinder<'a> { + pub(crate) binder: &'a mut Binder, +} + +impl ExpressionBinder<'_> { + pub fn bind_expression( + &mut self, + expr: &sqlparser::ast::Expr, + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + match expr { + sqlparser::ast::Expr::Identifier(ident) => { + self.bind_column_ref_expr(slice::from_ref(ident), result_names, result_types) + } + 
sqlparser::ast::Expr::CompoundIdentifier(idents) => { + self.bind_column_ref_expr(idents, result_names, result_types) + } + sqlparser::ast::Expr::BinaryOp { .. } => todo!(), + sqlparser::ast::Expr::UnaryOp { .. } => todo!(), + sqlparser::ast::Expr::Value(v) => { + self.bind_constant_expr(v, result_names, result_types) + } + sqlparser::ast::Expr::Function(_) => todo!(), + sqlparser::ast::Expr::Exists { .. } => todo!(), + sqlparser::ast::Expr::Subquery(_) => todo!(), + _ => todo!(), + } + } +} diff --git a/src/planner_v2/expression_iterator.rs b/src/planner_v2/expression_iterator.rs new file mode 100644 index 0000000..182fbfd --- /dev/null +++ b/src/planner_v2/expression_iterator.rs @@ -0,0 +1,18 @@ +use super::BoundExpression; + +pub struct ExpressionIterator; + +impl ExpressionIterator { + pub fn enumerate_children(expr: &mut BoundExpression, _callback: F) + where + F: Fn(&mut BoundExpression), + { + match expr { + BoundExpression::BoundColumnRefExpression(_) + | BoundExpression::BoundConstantExpression(_) + | BoundExpression::BoundReferenceExpression(_) => { + // these node types have no children + } + } + } +} diff --git a/src/planner_v2/logical_operator_visitor.rs b/src/planner_v2/logical_operator_visitor.rs new file mode 100644 index 0000000..04753fe --- /dev/null +++ b/src/planner_v2/logical_operator_visitor.rs @@ -0,0 +1,58 @@ +use super::{ + BoundColumnRefExpression, BoundConstantExpression, BoundExpression, BoundReferenceExpression, + ExpressionIterator, LogicalOperator, +}; + +/// Visitor pattern on logical operators, also includes rewrite expression ability. 
+pub trait LogicalOperatorVisitor { + fn visit_operator(&mut self, op: &mut LogicalOperator) { + self.visit_operator_children(op); + self.visit_operator_expressions(op); + } + + fn visit_operator_children(&mut self, op: &mut LogicalOperator) { + for child in op.children() { + self.visit_operator(child); + } + } + + fn visit_operator_expressions(&mut self, op: &mut LogicalOperator) { + Self::enumerate_expressions(op, |e| self.visit_expression(e)) + } + + fn enumerate_expressions(op: &mut LogicalOperator, callback: F) + where + F: Fn(&mut BoundExpression), + { + for expr in op.expressions() { + callback(expr); + } + } + + fn visit_expression(&self, expr: &mut BoundExpression) { + let result = match expr { + BoundExpression::BoundColumnRefExpression(e) => self.visit_replace_column_ref(e), + BoundExpression::BoundConstantExpression(e) => self.visit_replace_constant(e), + BoundExpression::BoundReferenceExpression(e) => self.visit_replace_reference(e), + }; + if let Some(new_expr) = result { + *expr = new_expr; + } else { + self.visit_expression_children(expr); + } + } + + fn visit_expression_children(&self, expr: &mut BoundExpression) { + ExpressionIterator::enumerate_children(expr, |e| self.visit_expression(e)) + } + + fn visit_replace_column_ref(&self, _: &BoundColumnRefExpression) -> Option { + None + } + fn visit_replace_constant(&self, _: &BoundConstantExpression) -> Option { + None + } + fn visit_replace_reference(&self, _: &BoundReferenceExpression) -> Option { + None + } +} diff --git a/src/planner_v2/mod.rs b/src/planner_v2/mod.rs new file mode 100644 index 0000000..c95cd55 --- /dev/null +++ b/src/planner_v2/mod.rs @@ -0,0 +1,54 @@ +mod binder; +mod constants; +mod errors; +mod expression_binder; +mod expression_iterator; +mod logical_operator_visitor; +mod operator; + +use std::sync::Arc; + +pub use binder::*; +pub use constants::*; +pub use errors::*; +pub use expression_binder::*; +pub use expression_iterator::*; +pub use logical_operator_visitor::*; +pub use 
operator::*; +use sqlparser::ast::Statement; + +use crate::main_entry::ClientContext; +use crate::types_v2::LogicalType; + +pub struct Planner { + binder: Binder, + #[allow(dead_code)] + client_context: Arc, + pub(crate) plan: Option, + pub(crate) types: Option>, + pub(crate) names: Option>, +} + +impl Planner { + pub fn new(client_context: Arc) -> Self { + Self { + binder: Binder::new(client_context.clone()), + client_context, + plan: None, + types: None, + names: None, + } + } + + pub fn create_plan(&mut self, statement: &Statement) -> Result<(), PlannerError> { + let bound_statement = self.binder.bind(statement)?; + self.plan = Some(bound_statement.plan); + self.names = Some(bound_statement.names); + self.types = Some(bound_statement.types); + // println!( + // "created_plan: {:#?}\nnames: {:?}\ntypes: {:?}", + // self.plan, self.names, self.types + // ); + Ok(()) + } +} diff --git a/src/planner_v2/operator/logical_create_table.rs b/src/planner_v2/operator/logical_create_table.rs new file mode 100644 index 0000000..83eba6a --- /dev/null +++ b/src/planner_v2/operator/logical_create_table.rs @@ -0,0 +1,11 @@ +use derive_new::new; + +use super::LogicalOperatorBase; +use crate::planner_v2::BoundCreateTableInfo; + +#[derive(new, Debug)] +pub struct LogicalCreateTable { + #[new(default)] + pub(crate) base: LogicalOperatorBase, + pub(crate) info: BoundCreateTableInfo, +} diff --git a/src/planner_v2/operator/logical_expression_get.rs b/src/planner_v2/operator/logical_expression_get.rs new file mode 100644 index 0000000..ad4a5b3 --- /dev/null +++ b/src/planner_v2/operator/logical_expression_get.rs @@ -0,0 +1,16 @@ +use derive_new::new; + +use super::LogicalOperatorBase; +use crate::planner_v2::BoundExpression; +use crate::types_v2::LogicalType; + +/// LogicalExpressionGet represents a scan operation over a set of to-be-executed expressions +#[derive(new, Debug)] +pub struct LogicalExpressionGet { + pub(crate) base: LogicalOperatorBase, + pub(crate) table_idx: usize, + 
/// The types of the expressions + pub(crate) expr_types: Vec, + /// The set of expressions + pub(crate) expressions: Vec>, +} diff --git a/src/planner_v2/operator/logical_get.rs b/src/planner_v2/operator/logical_get.rs new file mode 100644 index 0000000..3aea6f0 --- /dev/null +++ b/src/planner_v2/operator/logical_get.rs @@ -0,0 +1,22 @@ +use derive_new::new; + +use super::LogicalOperatorBase; +use crate::catalog_v2::TableCatalogEntry; +use crate::types_v2::LogicalType; + +/// LogicalGet represents a scan operation from a data source +#[derive(new, Debug)] +pub struct LogicalGet { + pub(crate) base: LogicalOperatorBase, + pub(crate) table_idx: usize, + // TODO: migrate to FunctionData when support TableFunction + pub(crate) bind_table: TableCatalogEntry, + /// The types of ALL columns that can be returned by the table function + pub(crate) returned_types: Vec, + /// The names of ALL columns that can be returned by the table function + pub(crate) names: Vec, + /// Bound column IDs + #[new(default)] + #[allow(dead_code)] + pub(crate) column_ids: Vec, +} diff --git a/src/planner_v2/operator/logical_insert.rs b/src/planner_v2/operator/logical_insert.rs new file mode 100644 index 0000000..cc1990d --- /dev/null +++ b/src/planner_v2/operator/logical_insert.rs @@ -0,0 +1,15 @@ +use derive_new::new; + +use super::LogicalOperatorBase; +use crate::catalog_v2::TableCatalogEntry; +use crate::types_v2::LogicalType; + +#[derive(new, Debug)] +pub struct LogicalInsert { + pub(crate) base: LogicalOperatorBase, + /// The insertion map ([table_index -> index in result, or INVALID_INDEX if not specified]) + pub(crate) column_index_list: Vec, + /// The expected types for the INSERT statement + pub(crate) expected_types: Vec, + pub(crate) table: TableCatalogEntry, +} diff --git a/src/planner_v2/operator/logical_projection.rs b/src/planner_v2/operator/logical_projection.rs new file mode 100644 index 0000000..482e097 --- /dev/null +++ b/src/planner_v2/operator/logical_projection.rs @@ -0,0 
+1,9 @@ +use derive_new::new; + +use super::LogicalOperatorBase; + +#[derive(new, Debug)] +pub struct LogicalProjection { + pub(crate) base: LogicalOperatorBase, + pub(crate) table_idx: usize, +} diff --git a/src/planner_v2/operator/mod.rs b/src/planner_v2/operator/mod.rs new file mode 100644 index 0000000..01c5463 --- /dev/null +++ b/src/planner_v2/operator/mod.rs @@ -0,0 +1,109 @@ +use crate::types_v2::LogicalType; + +mod logical_create_table; +mod logical_expression_get; +mod logical_get; +mod logical_insert; +mod logical_projection; +use derive_new::new; +pub use logical_create_table::*; +pub use logical_expression_get::*; +pub use logical_get::*; +pub use logical_insert::*; +pub use logical_projection::*; + +use super::{BoundExpression, ColumnBinding}; + +#[derive(new, Default, Debug)] +pub struct LogicalOperatorBase { + pub(crate) children: Vec, + // The set of expressions contained within the operator, if any + pub(crate) expressioins: Vec, + /// The types returned by this logical operator. 
+ pub(crate) types: Vec, +} + +#[derive(Debug)] +pub enum LogicalOperator { + LogicalCreateTable(LogicalCreateTable), + LogicalExpressionGet(LogicalExpressionGet), + LogicalInsert(LogicalInsert), + LogicalGet(LogicalGet), + LogicalProjection(LogicalProjection), +} + +impl LogicalOperator { + pub fn children(&mut self) -> &mut [LogicalOperator] { + match self { + LogicalOperator::LogicalCreateTable(op) => &mut op.base.children, + LogicalOperator::LogicalExpressionGet(op) => &mut op.base.children, + LogicalOperator::LogicalInsert(op) => &mut op.base.children, + LogicalOperator::LogicalGet(op) => &mut op.base.children, + LogicalOperator::LogicalProjection(op) => &mut op.base.children, + } + } + + pub fn expressions(&mut self) -> &mut [BoundExpression] { + match self { + LogicalOperator::LogicalCreateTable(op) => &mut op.base.expressioins, + LogicalOperator::LogicalExpressionGet(op) => &mut op.base.expressioins, + LogicalOperator::LogicalInsert(op) => &mut op.base.expressioins, + LogicalOperator::LogicalGet(op) => &mut op.base.expressioins, + LogicalOperator::LogicalProjection(op) => &mut op.base.expressioins, + } + } + + pub fn get_column_bindings(&self) -> Vec { + let default = vec![ColumnBinding::new(0, 0)]; + match self { + LogicalOperator::LogicalCreateTable(_) => default, + LogicalOperator::LogicalExpressionGet(op) => { + self.generate_column_bindings(op.table_idx, op.expr_types.len()) + } + LogicalOperator::LogicalInsert(_) => default, + LogicalOperator::LogicalGet(op) => { + self.generate_column_bindings(op.table_idx, op.returned_types.len()) + } + LogicalOperator::LogicalProjection(op) => { + self.generate_column_bindings(op.table_idx, op.base.expressioins.len()) + } + } + } + + pub fn resolve_operator_types(&mut self) { + for child in self.children() { + child.resolve_operator_types(); + } + match self { + LogicalOperator::LogicalCreateTable(op) => { + op.base.types.push(LogicalType::Bigint); + } + LogicalOperator::LogicalExpressionGet(op) => { + 
op.base.types = op.expr_types.clone(); + } + LogicalOperator::LogicalInsert(op) => op.base.types.push(LogicalType::Bigint), + LogicalOperator::LogicalGet(op) => op.base.types.extend(op.returned_types.clone()), + LogicalOperator::LogicalProjection(op) => { + let types = op + .base + .expressioins + .iter() + .map(|e| e.return_type()) + .collect::>(); + op.base.types.extend(types); + } + } + } + + fn generate_column_bindings( + &self, + table_idx: usize, + column_count: usize, + ) -> Vec { + let mut result = vec![]; + for idx in 0..column_count { + result.push(ColumnBinding::new(table_idx, idx)) + } + result + } +} diff --git a/src/storage_v2/local_storage.rs b/src/storage_v2/local_storage.rs new file mode 100644 index 0000000..c5a2774 --- /dev/null +++ b/src/storage_v2/local_storage.rs @@ -0,0 +1,99 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::record_batch::RecordBatch; +use derive_new::new; + +use crate::catalog_v2::DataTable; +use crate::main_entry::ClientContext; + +/// Used as in-memory storage +#[derive(Default)] +pub struct LocalStorage { + table_manager: LocalTableManager, +} + +impl LocalStorage { + fn append_internal(&mut self, table: &DataTable, batch: RecordBatch) { + self.table_manager.init_storage(table); + self.table_manager.append(table, batch); + } + + pub fn append(client_context: Arc, table: &DataTable, batch: RecordBatch) { + let mut storage = client_context.db.storage.try_write().unwrap(); + storage.append_internal(table, batch); + } + + pub fn create_reader(table: &DataTable) -> LocalStorageReader { + LocalStorageReader::new(table.clone()) + } +} + +#[derive(new)] +pub struct LocalStorageReader { + table: DataTable, + #[new(default)] + current_batch_cursor: usize, +} + +impl LocalStorageReader { + pub fn next_batch(&mut self, client_context: Arc) -> Option { + let storage = client_context.db.storage.try_read().unwrap(); + let batch = storage + .table_manager + .fetch_table_batch(&self.table, self.current_batch_cursor); + 
self.current_batch_cursor += 1; + batch + } +} + +#[derive(Default)] +pub struct LocalTableManager { + table_storage: HashMap, +} + +impl LocalTableManager { + pub fn init_storage(&mut self, table: &DataTable) { + if !self.table_storage.contains_key(table) { + let storage = LocalTableStorage::new(table.clone()); + self.table_storage.insert(table.clone(), storage); + } + } + + fn append(&mut self, table: &DataTable, batch: RecordBatch) { + self.table_storage.get_mut(table).unwrap().append(batch); + } + + pub fn fetch_table_batch(&self, table: &DataTable, batch_idx: usize) -> Option { + self.table_storage + .get(table) + .unwrap() + .fetch_batch(batch_idx) + } +} + +pub struct LocalTableStorage { + _table: DataTable, + data: Vec, +} + +impl LocalTableStorage { + pub fn new(table: DataTable) -> Self { + Self { + _table: table, + data: vec![], + } + } + + fn append(&mut self, batch: RecordBatch) { + self.data.push(batch); + } + + fn fetch_batch(&self, batch_idx: usize) -> Option { + if batch_idx >= self.data.len() { + None + } else { + Some(self.data[batch_idx].clone()) + } + } +} diff --git a/src/storage_v2/mod.rs b/src/storage_v2/mod.rs new file mode 100644 index 0000000..6af72bc --- /dev/null +++ b/src/storage_v2/mod.rs @@ -0,0 +1,2 @@ +mod local_storage; +pub use local_storage::*; diff --git a/src/types_v2/errors.rs b/src/types_v2/errors.rs new file mode 100644 index 0000000..babe8ee --- /dev/null +++ b/src/types_v2/errors.rs @@ -0,0 +1,9 @@ +#[derive(thiserror::Error, Debug)] +pub enum TypeError { + #[error("invalid logical type")] + InvalidLogicalType, + #[error("not implemented arrow datatype: {0}")] + NotImplementedArrowDataType(String), + #[error("not implemented sqlparser datatype: {0}")] + NotImplementedSqlparserDataType(String), +} diff --git a/src/types_v2/mod.rs b/src/types_v2/mod.rs new file mode 100644 index 0000000..25b50bf --- /dev/null +++ b/src/types_v2/mod.rs @@ -0,0 +1,6 @@ +mod errors; +mod types; +mod values; +pub use errors::*; +pub use 
types::*; +pub use values::*; diff --git a/src/types_v2/types.rs b/src/types_v2/types.rs new file mode 100644 index 0000000..24f6178 --- /dev/null +++ b/src/types_v2/types.rs @@ -0,0 +1,73 @@ +use super::TypeError; + +/// Sqlrs type conversion: +/// sqlparser::ast::DataType -> LogicalType -> arrow::datatypes::DataType +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum LogicalType { + Invalid, + Boolean, + Tinyint, + UTinyint, + Smallint, + USmallint, + Integer, + UInteger, + Bigint, + UBigint, + Float, + Double, + Varchar, +} + +/// sqlparser datatype to logical type +impl TryFrom for LogicalType { + type Error = TypeError; + + fn try_from(value: sqlparser::ast::DataType) -> Result { + match value { + sqlparser::ast::DataType::Char(_) + | sqlparser::ast::DataType::Varchar(_) + | sqlparser::ast::DataType::Nvarchar(_) + | sqlparser::ast::DataType::Text + | sqlparser::ast::DataType::String => Ok(LogicalType::Varchar), + sqlparser::ast::DataType::Float(_) => Ok(LogicalType::Float), + sqlparser::ast::DataType::Double => Ok(LogicalType::Double), + sqlparser::ast::DataType::TinyInt(_) => Ok(LogicalType::Tinyint), + sqlparser::ast::DataType::UnsignedTinyInt(_) => Ok(LogicalType::UTinyint), + sqlparser::ast::DataType::SmallInt(_) => Ok(LogicalType::Smallint), + sqlparser::ast::DataType::UnsignedSmallInt(_) => Ok(LogicalType::USmallint), + sqlparser::ast::DataType::Int(_) | sqlparser::ast::DataType::Integer(_) => { + Ok(LogicalType::Integer) + } + sqlparser::ast::DataType::UnsignedInt(_) + | sqlparser::ast::DataType::UnsignedInteger(_) => Ok(LogicalType::UInteger), + sqlparser::ast::DataType::BigInt(_) => Ok(LogicalType::Bigint), + sqlparser::ast::DataType::UnsignedBigInt(_) => Ok(LogicalType::UBigint), + sqlparser::ast::DataType::Boolean => Ok(LogicalType::Boolean), + other => Err(TypeError::NotImplementedSqlparserDataType( + other.to_string(), + )), + } + } +} + +impl From for arrow::datatypes::DataType { + fn from(value: LogicalType) -> Self { + use 
arrow::datatypes::DataType; + match value { + LogicalType::Invalid => panic!("invalid logical type"), + LogicalType::Boolean => DataType::Boolean, + LogicalType::Tinyint => DataType::Int8, + LogicalType::UTinyint => DataType::UInt8, + LogicalType::Smallint => DataType::Int16, + LogicalType::USmallint => DataType::UInt16, + LogicalType::Integer => DataType::Int32, + LogicalType::UInteger => DataType::UInt32, + LogicalType::Bigint => DataType::Int64, + LogicalType::UBigint => DataType::UInt64, + LogicalType::Float => DataType::Float32, + LogicalType::Double => DataType::Float64, + LogicalType::Varchar => DataType::Utf8, + } + } +} diff --git a/src/types_v2/values.rs b/src/types_v2/values.rs new file mode 100644 index 0000000..ab12d89 --- /dev/null +++ b/src/types_v2/values.rs @@ -0,0 +1,401 @@ +use std::cmp::Ordering; +use std::fmt; +use std::hash::Hash; +use std::iter::repeat; +use std::sync::Arc; + +use arrow::array::{ + new_null_array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, + Int64Array, Int8Array, StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, +}; +use arrow::datatypes::DataType; +use ordered_float::OrderedFloat; + +use super::{LogicalType, TypeError}; + +#[derive(Clone)] +pub enum ScalarValue { + Null, + Boolean(Option), + Float32(Option), + Float64(Option), + Int8(Option), + Int16(Option), + Int32(Option), + Int64(Option), + UInt8(Option), + UInt16(Option), + UInt32(Option), + UInt64(Option), + Utf8(Option), +} + +impl PartialEq for ScalarValue { + fn eq(&self, other: &Self) -> bool { + use ScalarValue::*; + match (self, other) { + (Boolean(v1), Boolean(v2)) => v1.eq(v2), + (Boolean(_), _) => false, + (Float32(v1), Float32(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.eq(&v2) + } + (Float32(_), _) => false, + (Float64(v1), Float64(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.eq(&v2) + } + (Float64(_), _) => false, + (Int8(v1), Int8(v2)) 
=> v1.eq(v2), + (Int8(_), _) => false, + (Int16(v1), Int16(v2)) => v1.eq(v2), + (Int16(_), _) => false, + (Int32(v1), Int32(v2)) => v1.eq(v2), + (Int32(_), _) => false, + (Int64(v1), Int64(v2)) => v1.eq(v2), + (Int64(_), _) => false, + (UInt8(v1), UInt8(v2)) => v1.eq(v2), + (UInt8(_), _) => false, + (UInt16(v1), UInt16(v2)) => v1.eq(v2), + (UInt16(_), _) => false, + (UInt32(v1), UInt32(v2)) => v1.eq(v2), + (UInt32(_), _) => false, + (UInt64(v1), UInt64(v2)) => v1.eq(v2), + (UInt64(_), _) => false, + (Utf8(v1), Utf8(v2)) => v1.eq(v2), + (Utf8(_), _) => false, + (Null, Null) => true, + (Null, _) => false, + } + } +} + +impl PartialOrd for ScalarValue { + fn partial_cmp(&self, other: &Self) -> Option { + use ScalarValue::*; + match (self, other) { + (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2), + (Boolean(_), _) => None, + (Float32(v1), Float32(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.partial_cmp(&v2) + } + (Float32(_), _) => None, + (Float64(v1), Float64(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.partial_cmp(&v2) + } + (Float64(_), _) => None, + (Int8(v1), Int8(v2)) => v1.partial_cmp(v2), + (Int8(_), _) => None, + (Int16(v1), Int16(v2)) => v1.partial_cmp(v2), + (Int16(_), _) => None, + (Int32(v1), Int32(v2)) => v1.partial_cmp(v2), + (Int32(_), _) => None, + (Int64(v1), Int64(v2)) => v1.partial_cmp(v2), + (Int64(_), _) => None, + (UInt8(v1), UInt8(v2)) => v1.partial_cmp(v2), + (UInt8(_), _) => None, + (UInt16(v1), UInt16(v2)) => v1.partial_cmp(v2), + (UInt16(_), _) => None, + (UInt32(v1), UInt32(v2)) => v1.partial_cmp(v2), + (UInt32(_), _) => None, + (UInt64(v1), UInt64(v2)) => v1.partial_cmp(v2), + (UInt64(_), _) => None, + (Utf8(v1), Utf8(v2)) => v1.partial_cmp(v2), + (Utf8(_), _) => None, + (Null, Null) => Some(Ordering::Equal), + (Null, _) => None, + } + } +} + +impl Eq for ScalarValue {} + +impl Hash for ScalarValue { + fn hash(&self, state: &mut H) { + use ScalarValue::*; + match 
self { + Boolean(v) => v.hash(state), + Float32(v) => { + let v = v.map(OrderedFloat); + v.hash(state) + } + Float64(v) => { + let v = v.map(OrderedFloat); + v.hash(state) + } + Int8(v) => v.hash(state), + Int16(v) => v.hash(state), + Int32(v) => v.hash(state), + Int64(v) => v.hash(state), + UInt8(v) => v.hash(state), + UInt16(v) => v.hash(state), + UInt32(v) => v.hash(state), + UInt64(v) => v.hash(state), + Utf8(v) => v.hash(state), + Null => 1.hash(state), + } + } +} + +macro_rules! typed_cast { + ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ + let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + ScalarValue::$SCALAR(match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), + }) + }}; +} + +macro_rules! build_array_from_option { + ($DATA_TYPE:ident, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{ + match $EXPR { + Some(value) => Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)), + None => new_null_array(&DataType::$DATA_TYPE, $SIZE), + } + }}; + ($DATA_TYPE:ident, $ENUM:expr, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{ + match $EXPR { + Some(value) => Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)), + None => new_null_array(&DataType::$DATA_TYPE($ENUM), $SIZE), + } + }}; + ($DATA_TYPE:ident, $ENUM:expr, $ENUM2:expr, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{ + match $EXPR { + Some(value) => { + let array: ArrayRef = Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)); + // Need to call cast to cast to final data type with timezone/extra param + cast(&array, &DataType::$DATA_TYPE($ENUM, $ENUM2)).expect("cannot do temporal cast") + } + None => new_null_array(&DataType::$DATA_TYPE($ENUM, $ENUM2), $SIZE), + } + }}; +} + +impl ScalarValue { + pub fn new_none_value(data_type: &DataType) -> Result { + match data_type { + DataType::Null => Ok(ScalarValue::Null), + DataType::Boolean => Ok(ScalarValue::Boolean(None)), + DataType::Float32 => Ok(ScalarValue::Float32(None)), + DataType::Float64 => 
Ok(ScalarValue::Float64(None)), + DataType::Int8 => Ok(ScalarValue::Int8(None)), + DataType::Int16 => Ok(ScalarValue::Int16(None)), + DataType::Int32 => Ok(ScalarValue::Int32(None)), + DataType::Int64 => Ok(ScalarValue::Int64(None)), + DataType::UInt8 => Ok(ScalarValue::UInt8(None)), + DataType::UInt16 => Ok(ScalarValue::UInt16(None)), + DataType::UInt32 => Ok(ScalarValue::UInt32(None)), + DataType::UInt64 => Ok(ScalarValue::UInt64(None)), + DataType::Utf8 => Ok(ScalarValue::Utf8(None)), + other => Err(TypeError::NotImplementedArrowDataType(other.to_string())), + } + } + + /// Converts a value in `array` at `index` into a ScalarValue + pub fn try_from_array(array: &ArrayRef, index: usize) -> Result { + if !array.is_valid(index) { + return Self::new_none_value(array.data_type()); + } + + use arrow::array::*; + + Ok(match array.data_type() { + DataType::Null => ScalarValue::Null, + DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), + DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), + DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), + DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64), + DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32), + DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16), + DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8), + DataType::Int64 => typed_cast!(array, index, Int64Array, Int64), + DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), + DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), + DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), + DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), + other => { + return Err(TypeError::NotImplementedArrowDataType(other.to_string())); + } + }) + } + + pub fn get_logical_type(&self) -> LogicalType { + match self { + ScalarValue::Null => LogicalType::Invalid, + ScalarValue::Boolean(_) => LogicalType::Boolean, + 
ScalarValue::Float32(_) => LogicalType::Float, + ScalarValue::Float64(_) => LogicalType::Double, + ScalarValue::Int8(_) => LogicalType::Tinyint, + ScalarValue::Int16(_) => LogicalType::Smallint, + ScalarValue::Int32(_) => LogicalType::Integer, + ScalarValue::Int64(_) => LogicalType::Bigint, + ScalarValue::UInt8(_) => LogicalType::UTinyint, + ScalarValue::UInt16(_) => LogicalType::USmallint, + ScalarValue::UInt32(_) => LogicalType::UInteger, + ScalarValue::UInt64(_) => LogicalType::UBigint, + ScalarValue::Utf8(_) => LogicalType::Varchar, + } + } + + /// Converts a scalar value into an 1-row array. + pub fn to_array(&self) -> ArrayRef { + self.to_array_of_size(1) + } + + /// Converts a scalar value into an array of `size` rows. + pub fn to_array_of_size(&self, size: usize) -> ArrayRef { + match self { + ScalarValue::Boolean(e) => Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef, + ScalarValue::Float64(e) => { + build_array_from_option!(Float64, Float64Array, e, size) + } + ScalarValue::Float32(e) => { + build_array_from_option!(Float32, Float32Array, e, size) + } + ScalarValue::Int8(e) => build_array_from_option!(Int8, Int8Array, e, size), + ScalarValue::Int16(e) => build_array_from_option!(Int16, Int16Array, e, size), + ScalarValue::Int32(e) => build_array_from_option!(Int32, Int32Array, e, size), + ScalarValue::Int64(e) => build_array_from_option!(Int64, Int64Array, e, size), + ScalarValue::UInt8(e) => build_array_from_option!(UInt8, UInt8Array, e, size), + ScalarValue::UInt16(e) => { + build_array_from_option!(UInt16, UInt16Array, e, size) + } + ScalarValue::UInt32(e) => { + build_array_from_option!(UInt32, UInt32Array, e, size) + } + ScalarValue::UInt64(e) => { + build_array_from_option!(UInt64, UInt64Array, e, size) + } + + ScalarValue::Utf8(e) => match e { + Some(value) => Arc::new(StringArray::from_iter_values(repeat(value).take(size))), + None => new_null_array(&DataType::Utf8, size), + }, + ScalarValue::Null => new_null_array(&DataType::Null, size), 
+ } + } +} + +macro_rules! impl_scalar { + ($ty:ty, $scalar:tt) => { + impl From<$ty> for ScalarValue { + fn from(value: $ty) -> Self { + ScalarValue::$scalar(Some(value)) + } + } + + impl From> for ScalarValue { + fn from(value: Option<$ty>) -> Self { + ScalarValue::$scalar(value) + } + } + }; +} + +impl_scalar!(f64, Float64); +impl_scalar!(f32, Float32); +impl_scalar!(i8, Int8); +impl_scalar!(i16, Int16); +impl_scalar!(i32, Int32); +impl_scalar!(i64, Int64); +impl_scalar!(bool, Boolean); +impl_scalar!(u8, UInt8); +impl_scalar!(u16, UInt16); +impl_scalar!(u32, UInt32); +impl_scalar!(u64, UInt64); +impl_scalar!(String, Utf8); + +impl From<&sqlparser::ast::Value> for ScalarValue { + fn from(v: &sqlparser::ast::Value) -> Self { + match v { + sqlparser::ast::Value::Number(n, _) => { + if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else { + panic!("unsupported number {:?}", n) + } + } + sqlparser::ast::Value::SingleQuotedString(s) => s.clone().into(), + sqlparser::ast::Value::DoubleQuotedString(s) => s.clone().into(), + sqlparser::ast::Value::Boolean(b) => (*b).into(), + sqlparser::ast::Value::Null => Self::Null, + _ => todo!("unsupported parsed scalar value {:?}", v), + } + } +} + +macro_rules! 
format_option { + ($F:expr, $EXPR:expr) => {{ + match $EXPR { + Some(e) => write!($F, "{}", e), + None => write!($F, "NULL"), + } + }}; +} + +impl fmt::Display for ScalarValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ScalarValue::Boolean(e) => format_option!(f, e)?, + ScalarValue::Float32(e) => format_option!(f, e)?, + ScalarValue::Float64(e) => format_option!(f, e)?, + ScalarValue::Int8(e) => format_option!(f, e)?, + ScalarValue::Int16(e) => format_option!(f, e)?, + ScalarValue::Int32(e) => format_option!(f, e)?, + ScalarValue::Int64(e) => format_option!(f, e)?, + ScalarValue::UInt8(e) => format_option!(f, e)?, + ScalarValue::UInt16(e) => format_option!(f, e)?, + ScalarValue::UInt32(e) => format_option!(f, e)?, + ScalarValue::UInt64(e) => format_option!(f, e)?, + ScalarValue::Utf8(e) => format_option!(f, e)?, + ScalarValue::Null => write!(f, "NULL")?, + }; + Ok(()) + } +} + +impl fmt::Debug for ScalarValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ScalarValue::Boolean(_) => write!(f, "Boolean({})", self), + ScalarValue::Float32(_) => write!(f, "Float32({})", self), + ScalarValue::Float64(_) => write!(f, "Float64({})", self), + ScalarValue::Int8(_) => write!(f, "Int8({})", self), + ScalarValue::Int16(_) => write!(f, "Int16({})", self), + ScalarValue::Int32(_) => write!(f, "Int32({})", self), + ScalarValue::Int64(_) => write!(f, "Int64({})", self), + ScalarValue::UInt8(_) => write!(f, "UInt8({})", self), + ScalarValue::UInt16(_) => write!(f, "UInt16({})", self), + ScalarValue::UInt32(_) => write!(f, "UInt32({})", self), + ScalarValue::UInt64(_) => write!(f, "UInt64({})", self), + ScalarValue::Utf8(None) => write!(f, "Utf8({})", self), + ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{}\")", self), + ScalarValue::Null => write!(f, "NULL"), + } + } +} diff --git a/src/util/mod.rs b/src/util/mod.rs index cafc406..be3adb0 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,10 +1,28 @@ -use 
arrow::datatypes::DataType; +use std::collections::HashMap; + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; use arrow::util::display::array_value_to_string; use arrow::util::pretty::print_batches; use crate::optimizer::PlanNode; +use crate::types_v2::LogicalType; + +pub fn pretty_batches_with(batches: &[RecordBatch], names: &[String], types: &[LogicalType]) { + let fields = names + .iter() + .zip(types.iter()) + .map(|(name, data_type)| Field::new(name.as_str(), data_type.clone().into(), true)) + .collect::>(); + let schema = SchemaRef::new(Schema::new_with_metadata(fields, HashMap::new())); + let batches = batches + .iter() + .map(|batch| RecordBatch::try_new(schema.clone(), batch.columns().to_vec())) + .collect::, ArrowError>>() + .unwrap(); + _ = print_batches(batches.as_slice()); +} pub fn pretty_batches(batches: &Vec) { _ = print_batches(batches.as_slice()); diff --git a/tests/slt/aggregation.slt b/tests/slt/aggregation.slt index 3696333..4f36907 100644 --- a/tests/slt/aggregation.slt +++ b/tests/slt/aggregation.slt @@ -1,18 +1,22 @@ +skipif sqlrs_v2 query II select sum(salary) from employee ---- 33500 +skipif sqlrs_v2 query II select sum(salary), sum(id+1), count(id), count(salary) from employee where id > 1 ---- 21500 12 3 2 +skipif sqlrs_v2 query II select max(salary), min(id), max(last_name) from employee ---- 12000 1 Travis +skipif sqlrs_v2 query IIIII select salary, count(id), sum(salary), max(salary), min(salary) from employee group by salary ---- @@ -21,6 +25,7 @@ select salary, count(id), sum(salary), max(salary), min(salary) from employee gr 11500 1 11500 11500 11500 NULL 1 NULL NULL NULL +skipif sqlrs_v2 query IIIII select state, count(state), sum(salary), max(salary), min(salary) from employee group by state ---- @@ -28,6 +33,7 @@ CA 1 12000 12000 12000 CO 2 21500 11500 10000 (empty) 1 NULL NULL NULL +skipif sqlrs_v2 query IIIIII select state, id, count(state), 
sum(salary), max(salary), min(salary) from employee group by state, id ---- diff --git a/tests/slt/alias.slt b/tests/slt/alias.slt index a8901fb..3377880 100644 --- a/tests/slt/alias.slt +++ b/tests/slt/alias.slt @@ -1,14 +1,17 @@ # expression alias +skipif sqlrs_v2 query I select a as c1 from t1 order by c1 desc limit 1; ---- 2 +skipif sqlrs_v2 query I select a as c1 from t1 where c1 = 1; ---- 1 +skipif sqlrs_v2 query II select sum(b) as c1, a as c2 from t1 group by c2 order by c1 desc; ---- @@ -17,21 +20,25 @@ select sum(b) as c1, a as c2 from t1 group by c2 order by c1 desc; 4 0 # table alias +skipif sqlrs_v2 query I select t.a from t1 t where t.b > 1 order by t.a desc limit 1; ---- 2 +skipif sqlrs_v2 query I select sum(t.a) as c1 from t1 as t ---- 5 +skipif sqlrs_v2 query I select t.* from t1 t where t.b > 1 order by t.a desc limit 1; ---- 2 7 9 +skipif sqlrs_v2 query I select t_1.a from t1 t_1 left join t2 t_2 on t_1.a=t_2.b and t_1.c > t_2.c; ---- @@ -42,17 +49,19 @@ select t_1.a from t1 t_1 left join t2 t_2 on t_1.a=t_2.b and t_1.c > t_2.c; 2 # subquery alias +skipif sqlrs_v2 query I select t.a from (select * from t1 where a > 1) t where t.b > 7; ---- 2 +skipif sqlrs_v2 query III select t.* from (select * from t1 where a > 1) t where t.b > 7; ---- 2 8 1 - +skipif sqlrs_v2 query I select t.v1 + 1 from (select a + 1 as v1 from t1 where a > 1) t; ---- diff --git a/tests/slt/create_table.slt b/tests/slt/create_table.slt new file mode 100644 index 0000000..6a70df3 --- /dev/null +++ b/tests/slt/create_table.slt @@ -0,0 +1,19 @@ +onlyif sqlrs_v2 +statement ok +create table t1(a varchar, b varchar, c varchar); + +onlyif sqlrs_v2 +statement ok +insert into t1(c, b) values ('0','4'),('1','5'); + +onlyif sqlrs_v2 +statement ok +insert into t1 values ('2','7','9'); + +onlyif sqlrs_v2 +query III +select a, c, b from t1; +---- +NULL 0 4 +NULL 1 5 +2 9 7 diff --git a/tests/slt/distinct.slt b/tests/slt/distinct.slt index 6364114..937d409 100644 --- a/tests/slt/distinct.slt 
+++ b/tests/slt/distinct.slt @@ -1,3 +1,4 @@ +skipif sqlrs_v2 query I select distinct state from employee; ---- @@ -5,6 +6,7 @@ CA CO (empty) +skipif sqlrs_v2 query II select distinct a, b from t2; ---- @@ -13,16 +15,19 @@ select distinct a, b from t2; 30 3 40 4 +skipif sqlrs_v2 query I select sum(distinct b) from t2; ---- 9 +skipif sqlrs_v2 query I select sum(distinct(b)) from t2; ---- 9 +skipif sqlrs_v2 query I select sum(distinct(b)) from t2 group by c; ---- @@ -30,6 +35,7 @@ select sum(distinct(b)) from t2 group by c; 2 7 +skipif sqlrs_v2 query I select count(distinct(b)) from t2; ---- diff --git a/tests/slt/filter.slt b/tests/slt/filter.slt index cb8234b..897331e 100644 --- a/tests/slt/filter.slt +++ b/tests/slt/filter.slt @@ -1,14 +1,17 @@ +skipif sqlrs_v2 query II select first_name from employee where id > 2 ---- John Von +skipif sqlrs_v2 query II select id, first_name from employee where id > 2 and id < 4 ---- 3 John +skipif sqlrs_v2 query II select id, first_name from employee where id > 3 or id = 1 ---- diff --git a/tests/slt/join.slt b/tests/slt/join.slt index 64da005..7f19abf 100644 --- a/tests/slt/join.slt +++ b/tests/slt/join.slt @@ -1,3 +1,4 @@ +skipif sqlrs_v2 query III select employee.id, employee.first_name, employee.department_id, department.department_name, department.id from employee left join department on employee.department_id = department.id; @@ -7,6 +8,7 @@ from employee left join department on employee.department_id = department.id; 3 John 4 Engineering 4 4 Von NULL NULL NULL +skipif sqlrs_v2 query III select employee.id, employee.first_name, employee.department_id, department.department_name, department.id from employee right join department on employee.department_id = department.id; @@ -16,6 +18,7 @@ from employee right join department on employee.department_id = department.id; NULL NULL NULL Finance 3 3 John 4 Engineering 4 +skipif sqlrs_v2 query III select employee.id, employee.first_name, employee.department_id, 
department.department_name, department.id from employee inner join department on employee.department_id = department.id; @@ -24,6 +27,7 @@ from employee inner join department on employee.department_id = department.id; 2 Gregg 2 Marketing 2 3 John 4 Engineering 4 +skipif sqlrs_v2 query III select employee.id, employee.first_name, employee.department_id, department.department_name, department.id from employee full join department on employee.department_id = department.id; @@ -35,6 +39,7 @@ NULL NULL NULL Finance 3 4 Von NULL NULL NULL +skipif sqlrs_v2 query IIIII select employee.id, employee.first_name, department.department_name, state.state_name, state.state_code from employee left join department on employee.department_id=department.id @@ -45,6 +50,7 @@ right join state on state.state_code=employee.state; 3 John Engineering Colorado State CO NULL NULL NULL New Jersey NJ +skipif sqlrs_v2 query IIIII select employee.id, employee.first_name, department.department_name, state.state_name, state.state_code from employee left join department on employee.department_id=department.id @@ -55,6 +61,7 @@ left join state on state.state_code=employee.state; 3 John Engineering Colorado State CO 4 Von NULL NULL NULL +skipif sqlrs_v2 query IIIII select employee.id, employee.first_name, department.department_name, state.state_name, state.state_code from employee left join department on employee.department_id=department.id @@ -64,6 +71,7 @@ inner join state on state.state_code=employee.state; 2 Gregg Marketing Colorado State CO 3 John Engineering Colorado State CO +skipif sqlrs_v2 query IIIII select employee.id, employee.first_name, department.department_name, state.state_name, state.state_code from employee left join department on employee.department_id=department.id @@ -76,6 +84,7 @@ NULL NULL NULL New Jersey NJ 4 Von NULL NULL NULL +skipif sqlrs_v2 query IIIIII select t1.*, t2.* from t1, t2 where t1.a = 0; ---- @@ -84,6 +93,7 @@ select t1.*, t2.* from t1, t2 where t1.a = 0; 0 4 7 
30 3 6 0 4 7 40 4 6 +skipif sqlrs_v2 query IIIIII select t1.*, t2.* from t1 cross join t2 where t1.a = 0; ---- diff --git a/tests/slt/join_filter.slt b/tests/slt/join_filter.slt index eea7493..33ce99e 100644 --- a/tests/slt/join_filter.slt +++ b/tests/slt/join_filter.slt @@ -1,3 +1,4 @@ +skipif sqlrs_v2 query III select employee.id, employee.first_name, employee.state, state.state_name from employee left join state on employee.state=state.state_code and state.state_name!='California State'; @@ -7,6 +8,7 @@ from employee left join state on employee.state=state.state_code and state.state 1 Bill CA NULL 4 Von (empty) NULL +skipif sqlrs_v2 query IIIIII select t1.*, t2.* from t1 inner join t2 on t1.a=t2.b; ---- @@ -15,12 +17,14 @@ select t1.*, t2.* from t1 inner join t2 on t1.a=t2.b; 2 7 9 20 2 5 2 8 1 20 2 5 +skipif sqlrs_v2 query IIIIII select t1.*, t2.* from t1 inner join t2 on t1.a=t2.b and t1.c > t2.c; ---- 2 7 9 10 2 7 2 7 9 20 2 5 +skipif sqlrs_v2 query IIIIII select t1.*, t2.* from t1 left join t2 on t1.a=t2.b; ---- @@ -31,6 +35,7 @@ select t1.*, t2.* from t1 left join t2 on t1.a=t2.b; 0 4 7 NULL NULL NULL 1 5 8 NULL NULL NULL +skipif sqlrs_v2 query IIIIII select t1.*, t2.* from t1 left join t2 on t1.a=t2.b and t1.c > t2.c; ---- @@ -40,6 +45,7 @@ select t1.*, t2.* from t1 left join t2 on t1.a=t2.b and t1.c > t2.c; 1 5 8 NULL NULL NULL 2 8 1 NULL NULL NULL +skipif sqlrs_v2 query IIIIII select t1.*, t2.* from t1 right join t2 on t1.a=t2.b; ---- @@ -50,6 +56,7 @@ select t1.*, t2.* from t1 right join t2 on t1.a=t2.b; NULL NULL NULL 30 3 6 NULL NULL NULL 40 4 6 +skipif sqlrs_v2 query IIIIII select t1.*, t2.* from t1 right join t2 on t1.a=t2.b and t1.c > t2.c; ---- @@ -58,6 +65,7 @@ select t1.*, t2.* from t1 right join t2 on t1.a=t2.b and t1.c > t2.c; NULL NULL NULL 30 3 6 NULL NULL NULL 40 4 6 +skipif sqlrs_v2 query IIIIII select t1.*, t2.* from t1 full join t2 on t1.a=t2.b; ---- @@ -70,6 +78,7 @@ NULL NULL NULL 40 4 6 0 4 7 NULL NULL NULL 1 5 8 NULL NULL NULL 
+skipif sqlrs_v2 query IIIIII select t1.*, t2.* from t1 full join t2 on t1.a=t2.b and t1.c > t2.c; ---- diff --git a/tests/slt/limit.slt b/tests/slt/limit.slt index 404137c..ca6fa22 100644 --- a/tests/slt/limit.slt +++ b/tests/slt/limit.slt @@ -1,23 +1,28 @@ +skipif sqlrs_v2 query II select id from employee limit 2 offset 1 ---- 2 3 +skipif sqlrs_v2 query II select id from employee limit 1 offset 10 ---- +skipif sqlrs_v2 query II select id from employee limit 0 offset 0 ---- +skipif sqlrs_v2 query II select id from employee offset 2 ---- 3 4 +skipif sqlrs_v2 query II select id from employee limit 2 ---- diff --git a/tests/slt/order.slt b/tests/slt/order.slt index 7bc19f1..8738ec6 100644 --- a/tests/slt/order.slt +++ b/tests/slt/order.slt @@ -1,8 +1,10 @@ +skipif sqlrs_v2 query I select id from employee order by id desc offset 2 limit 1; ---- 2 +skipif sqlrs_v2 query II select id, state from employee order by state, id desc ---- @@ -11,6 +13,7 @@ select id, state from employee order by state, id desc 3 CO 2 CO +skipif sqlrs_v2 query I select id from employee order by first_name desc offset 2 limit 1; ---- diff --git a/tests/slt/select.slt b/tests/slt/select.slt index d975a4d..58734cb 100644 --- a/tests/slt/select.slt +++ b/tests/slt/select.slt @@ -1,3 +1,4 @@ +skipif sqlrs_v2 query IIII select first_name, state, id, salary from employee ---- diff --git a/tests/slt/subquery.slt b/tests/slt/subquery.slt index 8eb75ca..3dbb104 100644 --- a/tests/slt/subquery.slt +++ b/tests/slt/subquery.slt @@ -1,30 +1,36 @@ # subquery as source # subquery in FROM must have an alias. 
same behavior as Postgres +skipif sqlrs_v2 statement error select * from (select * from t1 where a > 1) where b > 7; +skipif sqlrs_v2 query III select * from (select * from t1 where c < 2) t_1; ---- 2 8 1 +skipif sqlrs_v2 query III select * from (select * from (select * from t1 where c < 2) t_1 where t_1.a > 1) t_2 where t_2.b > 7; ---- 2 8 1 +skipif sqlrs_v2 query III select t.* from (select * from t1 where a > 1) t where t.b > 7; ---- 2 8 1 +skipif sqlrs_v2 query II select t.b from (select a, b from t1 where a > 1) t where t.b > 7; ---- 8 +skipif sqlrs_v2 query III select t_2.* from (select t_1.* from (select * from t1 where c < 2) t_1 where t_1.a > 1) t_2 where t_2.b > 7; ---- @@ -32,6 +38,7 @@ select t_2.* from (select t_1.* from (select * from t1 where c < 2) t_1 where t_ # scalar subquery +skipif sqlrs_v2 query II select a, t2.v1 as max_b from t1 cross join (select max(b) as v1 from t1) t2; ---- @@ -40,6 +47,7 @@ select a, t2.v1 as max_b from t1 cross join (select max(b) as v1 from t1) t2; 2 8 2 8 +skipif sqlrs_v2 query II select a, (select max(b) from t1) max_b from t1; ---- @@ -48,6 +56,7 @@ select a, (select max(b) from t1) max_b from t1; 2 8 2 8 +skipif sqlrs_v2 query II select a, (select max(b) from t1) from t1; ---- @@ -56,6 +65,7 @@ select a, (select max(b) from t1) from t1; 2 8 2 8 +skipif sqlrs_v2 query II select a, (select max(b) from t1) + 2 as max_b from t1; ---- @@ -64,6 +74,7 @@ select a, (select max(b) from t1) + 2 as max_b from t1; 2 10 2 10 +skipif sqlrs_v2 query II select a, (select max(b) from t1) + (select min(b) from t1) as mix_b from t1; ---- @@ -72,12 +83,14 @@ select a, (select max(b) from t1) + (select min(b) from t1) as mix_b from t1; 2 12 2 12 +skipif sqlrs_v2 query I select t1.a, t1.b from t1 where a >= (select max(a) from t1); ---- 2 7 2 8 +skipif sqlrs_v2 query I select t1.a, t1.b from t1 where a >= (select max(a) from t1) and b = (select max(b) from t1); ---- diff --git a/tests/sqllogictest/src/lib.rs 
b/tests/sqllogictest/src/lib.rs index b6fd323..efba3de 100644 --- a/tests/sqllogictest/src/lib.rs +++ b/tests/sqllogictest/src/lib.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use sqllogictest::{AsyncDB, Runner}; use sqlrs::db::{Database, DatabaseError}; +use sqlrs::main_entry::{ClientContext, DatabaseError as DatabaseErrorV2, DatabaseInstance}; use sqlrs::util::record_batch_to_string; fn init_tables(db: Arc) { @@ -42,3 +43,29 @@ impl AsyncDB for DatabaseWrapper { Ok(output?) } } + +struct DatabaseWrapperV2 { + client_context: Arc, +} + +#[async_trait::async_trait] +impl AsyncDB for DatabaseWrapperV2 { + type Error = DatabaseErrorV2; + + async fn run(&mut self, sql: &str) -> Result { + let chunks = self.client_context.query(sql.to_string()).await?; + let output = chunks.iter().map(record_batch_to_string).try_collect(); + Ok(output?) + } + + fn engine_name(&self) -> &str { + "sqlrs_v2" + } +} + +pub fn test_run_v2(sqlfile: &str) { + let dbv2 = Arc::new(DatabaseInstance::default()); + let client_context = ClientContext::new(dbv2); + let mut tester = Runner::new(DatabaseWrapperV2 { client_context }); + tester.run_file(sqlfile).unwrap() +} diff --git a/tests/sqllogictest/tests/sqllogictest.rs b/tests/sqllogictest/tests/sqllogictest.rs index 32d6ee6..7295e7d 100644 --- a/tests/sqllogictest/tests/sqllogictest.rs +++ b/tests/sqllogictest/tests/sqllogictest.rs @@ -1,5 +1,5 @@ use libtest_mimic::{Arguments, Trial}; -use sqllogictest_test::test_run; +use sqllogictest_test::{test_run, test_run_v2}; fn main() { const SLT_PATTERN: &str = "../slt/**/*.slt"; @@ -20,6 +20,7 @@ fn main() { let test = Trial::test(filename, move || { test_run(filepath.as_str()); + test_run_v2(filepath.as_str()); Ok(()) });