diff --git a/src/catalog_v2/catalog.rs b/src/catalog_v2/catalog.rs index 30b22c6..f71c857 100644 --- a/src/catalog_v2/catalog.rs +++ b/src/catalog_v2/catalog.rs @@ -1,7 +1,8 @@ use std::sync::Arc; use super::entry::{CatalogEntry, DataTable}; -use super::{CatalogError, CatalogSet, TableCatalogEntry}; +use super::{CatalogError, CatalogSet, TableCatalogEntry, TableFunctionCatalogEntry}; +use crate::common::CreateTableFunctionInfo; use crate::main_entry::ClientContext; /// The Catalog object represents the catalog of the database. @@ -57,4 +58,55 @@ impl Catalog { } Err(CatalogError::CatalogEntryTypeNotMatch) } + + pub fn create_table_function( + client_context: Arc, + info: CreateTableFunctionInfo, + ) -> Result<(), CatalogError> { + let mut catalog = match client_context.db.catalog.try_write() { + Ok(c) => c, + Err(_) => return Err(CatalogError::CatalogLockedError), + }; + if let CatalogEntry::SchemaCatalogEntry(mut entry) = + catalog.schemas.get_entry(info.base.schema.clone())? + { + catalog.catalog_version += 1; + let schema = info.base.schema.clone(); + entry.create_table_function(catalog.catalog_version, info)?; + catalog + .schemas + .replace_entry(schema, CatalogEntry::SchemaCatalogEntry(entry))?; + return Ok(()); + } + Err(CatalogError::CatalogEntryTypeNotMatch) + } + + pub fn scan_entries( + client_context: Arc, + callback: &F, + ) -> Result, CatalogError> + where + F: Fn(&CatalogEntry) -> bool, + { + let catalog = match client_context.db.catalog.try_read() { + Ok(c) => c, + Err(_) => return Err(CatalogError::CatalogLockedError), + }; + Ok(catalog.schemas.scan_entries(callback)) + } + + pub fn get_table_function( + client_context: Arc, + schema: String, + table_function: String, + ) -> Result { + let catalog = match client_context.db.catalog.try_read() { + Ok(c) => c, + Err(_) => return Err(CatalogError::CatalogLockedError), + }; + if let CatalogEntry::SchemaCatalogEntry(entry) = catalog.schemas.get_entry(schema)? 
{ + return entry.get_table_function(table_function); + } + Err(CatalogError::CatalogEntryTypeNotMatch) + } } diff --git a/src/catalog_v2/catalog_set.rs b/src/catalog_v2/catalog_set.rs index ccac973..f5c2cf3 100644 --- a/src/catalog_v2/catalog_set.rs +++ b/src/catalog_v2/catalog_set.rs @@ -46,4 +46,17 @@ impl CatalogSet { } Err(CatalogError::CatalogEntryNotExists(name)) } + + pub fn scan_entries(&self, callback: &F) -> Vec + where + F: Fn(&CatalogEntry) -> bool, + { + let mut result = vec![]; + for (_, entry) in self.entries.iter() { + if callback(entry) { + result.push(entry.clone()); + } + } + result + } } diff --git a/src/catalog_v2/entry/mod.rs b/src/catalog_v2/entry/mod.rs index 0233b3e..ce1ce85 100644 --- a/src/catalog_v2/entry/mod.rs +++ b/src/catalog_v2/entry/mod.rs @@ -1,14 +1,17 @@ mod schema_catalog_entry; mod table_catalog_entry; +mod table_function_catalog_entry; use derive_new::new; pub use schema_catalog_entry::*; pub use table_catalog_entry::*; +pub use table_function_catalog_entry::*; #[derive(Clone, Debug)] pub enum CatalogEntry { SchemaCatalogEntry(SchemaCatalogEntry), TableCatalogEntry(TableCatalogEntry), + TableFunctionCatalogEntry(TableFunctionCatalogEntry), } impl CatalogEntry { @@ -21,7 +24,7 @@ impl CatalogEntry { #[derive(new, Clone, Debug)] pub struct CatalogEntryBase { /// The object identifier of the entry - oid: usize, + pub(crate) oid: usize, /// The name of the entry - name: String, + pub(crate) name: String, } diff --git a/src/catalog_v2/entry/schema_catalog_entry.rs b/src/catalog_v2/entry/schema_catalog_entry.rs index 25f3c07..0dac53d 100644 --- a/src/catalog_v2/entry/schema_catalog_entry.rs +++ b/src/catalog_v2/entry/schema_catalog_entry.rs @@ -1,12 +1,14 @@ use super::table_catalog_entry::{DataTable, TableCatalogEntry}; -use super::{CatalogEntry, CatalogEntryBase}; +use super::{CatalogEntry, CatalogEntryBase, TableFunctionCatalogEntry}; use crate::catalog_v2::{CatalogError, CatalogSet}; +use crate::common::CreateTableFunctionInfo; 
#[allow(dead_code)] #[derive(Clone, Debug)] pub struct SchemaCatalogEntry { base: CatalogEntryBase, tables: CatalogSet, + functions: CatalogSet, } impl SchemaCatalogEntry { @@ -14,6 +16,7 @@ impl SchemaCatalogEntry { Self { base: CatalogEntryBase::new(oid, schema), tables: CatalogSet::default(), + functions: CatalogSet::default(), } } @@ -23,8 +26,12 @@ impl SchemaCatalogEntry { table: String, storage: DataTable, ) -> Result<(), CatalogError> { - let entry = - CatalogEntry::TableCatalogEntry(TableCatalogEntry::new(oid, table.clone(), storage)); + let entry = CatalogEntry::TableCatalogEntry(TableCatalogEntry::new( + oid, + table.clone(), + self.base.clone(), + storage, + )); self.tables.create_entry(table, entry)?; Ok(()) } @@ -35,4 +42,28 @@ impl SchemaCatalogEntry { _ => Err(CatalogError::CatalogEntryNotExists(table)), } } + + pub fn create_table_function( + &mut self, + oid: usize, + info: CreateTableFunctionInfo, + ) -> Result<(), CatalogError> { + let entry = TableFunctionCatalogEntry::new( + CatalogEntryBase::new(oid, info.name.clone()), + info.functions, + ); + let entry = CatalogEntry::TableFunctionCatalogEntry(entry); + self.functions.create_entry(info.name, entry)?; + Ok(()) + } + + pub fn get_table_function( + &self, + table_function: String, + ) -> Result { + match self.functions.get_entry(table_function.clone())? 
{ + CatalogEntry::TableFunctionCatalogEntry(e) => Ok(e), + _ => Err(CatalogError::CatalogEntryNotExists(table_function)), + } + } } diff --git a/src/catalog_v2/entry/table_catalog_entry.rs b/src/catalog_v2/entry/table_catalog_entry.rs index 097b547..b908d38 100644 --- a/src/catalog_v2/entry/table_catalog_entry.rs +++ b/src/catalog_v2/entry/table_catalog_entry.rs @@ -9,6 +9,7 @@ use crate::types_v2::LogicalType; #[derive(Clone, Debug)] pub struct TableCatalogEntry { pub(crate) base: CatalogEntryBase, + pub(crate) schema_base: CatalogEntryBase, pub(crate) storage: DataTable, /// A list of columns that are part of this table pub(crate) columns: Vec, @@ -17,7 +18,12 @@ pub struct TableCatalogEntry { } impl TableCatalogEntry { - pub fn new(oid: usize, table: String, storage: DataTable) -> Self { + pub fn new( + oid: usize, + table: String, + schema_base: CatalogEntryBase, + storage: DataTable, + ) -> Self { let mut name_map = HashMap::new(); let mut columns = vec![]; storage @@ -30,6 +36,7 @@ impl TableCatalogEntry { }); Self { base: CatalogEntryBase::new(oid, table), + schema_base, storage, columns, name_map, diff --git a/src/catalog_v2/entry/table_function_catalog_entry.rs b/src/catalog_v2/entry/table_function_catalog_entry.rs new file mode 100644 index 0000000..6833e83 --- /dev/null +++ b/src/catalog_v2/entry/table_function_catalog_entry.rs @@ -0,0 +1,12 @@ +use derive_new::new; + +use super::CatalogEntryBase; +use crate::function::TableFunction; + +#[derive(new, Clone, Debug)] +pub struct TableFunctionCatalogEntry { + #[allow(dead_code)] + pub(crate) base: CatalogEntryBase, + #[allow(dead_code)] + pub(crate) functions: Vec, +} diff --git a/src/planner_v2/binder/statement/create_info.rs b/src/common/create_info.rs similarity index 59% rename from src/planner_v2/binder/statement/create_info.rs rename to src/common/create_info.rs index ebb85d5..5508454 100644 --- a/src/planner_v2/binder/statement/create_info.rs +++ b/src/common/create_info.rs @@ -1,6 +1,12 @@ use 
derive_new::new; use crate::catalog_v2::ColumnDefinition; +use crate::function::TableFunction; + +#[derive(new, Debug, Clone)] +pub struct CreateInfoBase { + pub(crate) schema: String, +} #[derive(new, Debug, Clone)] pub struct CreateTableInfo { @@ -11,7 +17,11 @@ pub struct CreateTableInfo { pub(crate) columns: Vec, } -#[derive(new, Debug, Clone)] -pub struct CreateInfoBase { - pub(crate) schema: String, +#[derive(new)] +pub struct CreateTableFunctionInfo { + pub(crate) base: CreateInfoBase, + /// Function name + pub(crate) name: String, + /// Functions with different arguments + pub(crate) functions: Vec, } diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 0000000..b122613 --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1,3 @@ +mod create_info; + +pub use create_info::*; diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 5f61378..f15c0c2 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -16,6 +16,7 @@ pub use util::*; pub use volcano_executor::*; use crate::catalog_v2::CatalogError; +use crate::function::FunctionError; use crate::main_entry::ClientContext; use crate::types_v2::TypeError; @@ -52,6 +53,12 @@ pub enum ExecutorError { #[from] TypeError, ), + #[error("function error: {0}")] + FunctionError( + #[source] + #[from] + FunctionError, + ), #[error("Executor internal error: {0}")] InternalError(String), } diff --git a/src/execution/physical_plan/physical_table_scan.rs b/src/execution/physical_plan/physical_table_scan.rs index 6117b63..325a294 100644 --- a/src/execution/physical_plan/physical_table_scan.rs +++ b/src/execution/physical_plan/physical_table_scan.rs @@ -1,15 +1,16 @@ use derive_new::new; use super::{PhysicalOperator, PhysicalOperatorBase}; -use crate::catalog_v2::TableCatalogEntry; use crate::execution::PhysicalPlanGenerator; +use crate::function::{FunctionData, TableFunction}; use crate::planner_v2::LogicalGet; use crate::types_v2::LogicalType; #[derive(new, Clone)] pub struct 
PhysicalTableScan { pub(crate) base: PhysicalOperatorBase, - pub(crate) bind_table: TableCatalogEntry, + pub(crate) function: TableFunction, + pub(crate) bind_data: FunctionData, /// The types of ALL columns that can be returned by the table function pub(crate) returned_types: Vec, /// The names of ALL columns that can be returned by the table function @@ -19,7 +20,8 @@ pub struct PhysicalTableScan { impl PhysicalPlanGenerator { pub(crate) fn create_physical_table_scan(&self, op: LogicalGet) -> PhysicalOperator { let base = PhysicalOperatorBase::new(vec![], op.base.types); - let plan = PhysicalTableScan::new(base, op.bind_table, op.returned_types, op.names); + let plan = + PhysicalTableScan::new(base, op.function, op.bind_data, op.returned_types, op.names); PhysicalOperator::PhysicalTableScan(plan) } } diff --git a/src/execution/volcano_executor/table_scan.rs b/src/execution/volcano_executor/table_scan.rs index b67f2a2..261a74f 100644 --- a/src/execution/volcano_executor/table_scan.rs +++ b/src/execution/volcano_executor/table_scan.rs @@ -5,7 +5,9 @@ use derive_new::new; use futures_async_stream::try_stream; use crate::execution::{ExecutionContext, ExecutorError, PhysicalTableScan, SchemaUtil}; -use crate::storage_v2::LocalStorage; +use crate::function::{ + GlobalTableFunctionState, SeqTableScanInitInput, TableFunctionInitInput, TableFunctionInput, +}; #[derive(new)] pub struct TableScan { @@ -17,10 +19,22 @@ impl TableScan { pub async fn execute(self, context: Arc) { let schema = SchemaUtil::new_schema_ref(&self.plan.names, &self.plan.returned_types); - let table = self.plan.bind_table; - let mut local_storage_reader = LocalStorage::create_reader(&table.storage); - let client_context = context.clone_client_context(); - while let Some(batch) = local_storage_reader.next_batch(client_context.clone()) { + let bind_data = self.plan.bind_data; + + let table_scan_func = self.plan.function.function; + let global_state = if let Some(init_global_func) = 
self.plan.function.init_global { + let seq_table_scan_init_input = TableFunctionInitInput::SeqTableScanInitInput( + Box::new(SeqTableScanInitInput::new(bind_data.clone())), + ); + init_global_func(context.clone_client_context(), seq_table_scan_init_input)? + } else { + GlobalTableFunctionState::None + }; + + let mut table_scan_input = TableFunctionInput::new(bind_data, global_state); + while let Some(batch) = + table_scan_func(context.clone_client_context(), &mut table_scan_input)? + { + let columns = batch.columns().to_vec(); + yield RecordBatch::try_new(schema.clone(), columns)? + } } diff --git a/src/function/errors.rs b/src/function/errors.rs new file mode 100644 index 0000000..8f2746a --- /dev/null +++ b/src/function/errors.rs @@ -0,0 +1,28 @@ +use arrow::error::ArrowError; + +use crate::catalog_v2::CatalogError; +use crate::types_v2::TypeError; + +#[derive(thiserror::Error, Debug)] +pub enum FunctionError { + #[error("catalog error: {0}")] + CatalogError( + #[from] + #[source] + CatalogError, + ), + #[error("type error: {0}")] + TypeError( + #[from] + #[source] + TypeError, + ), + #[error("arrow error: {0}")] + ArrowError( + #[from] + #[source] + ArrowError, + ), + #[error("Internal error: {0}")] + InternalError(String), +} diff --git a/src/function/mod.rs b/src/function/mod.rs new file mode 100644 index 0000000..b0e4d4a --- /dev/null +++ b/src/function/mod.rs @@ -0,0 +1,11 @@ +mod errors; +mod table; + +pub use errors::*; +pub use table::*; + +#[derive(Debug, Clone)] +pub enum FunctionData { + SeqTableScanInputData(Box), + None, +} diff --git a/src/function/table/mod.rs b/src/function/table/mod.rs new file mode 100644 index 0000000..c61bd1b --- /dev/null +++ b/src/function/table/mod.rs @@ -0,0 +1,4 @@ +mod seq_table_scan; +mod table_function; +pub use seq_table_scan::*; +pub use table_function::*; diff --git a/src/function/table/seq_table_scan.rs b/src/function/table/seq_table_scan.rs new file mode 100644 index 0000000..22eee15 --- /dev/null +++ 
b/src/function/table/seq_table_scan.rs @@ -0,0 +1,72 @@ +use std::sync::Arc; + +use arrow::record_batch::RecordBatch; +use derive_new::new; + +use super::{TableFunction, TableFunctionBindInput, TableFunctionInput}; +use crate::catalog_v2::TableCatalogEntry; +use crate::function::{FunctionData, FunctionError}; +use crate::main_entry::ClientContext; +use crate::storage_v2::{LocalStorage, LocalStorageReader}; + +/// The table scan function represents a sequential scan over one of base tables. +pub struct SeqTableScan; + +#[derive(new, Debug, Clone)] +pub struct SeqTableScanInputData { + pub(crate) bind_table: TableCatalogEntry, + pub(crate) local_storage_reader: LocalStorageReader, +} + +#[derive(new)] +pub struct SeqTableScanBindInput { + pub(crate) bind_table: TableCatalogEntry, +} + +#[derive(new)] +pub struct SeqTableScanInitInput { + #[allow(dead_code)] + pub(crate) bind_data: FunctionData, +} + +impl SeqTableScan { + fn seq_table_scan_bind_func( + input: TableFunctionBindInput, + ) -> Result, FunctionError> { + if let TableFunctionBindInput::SeqTableScanBindInput(bind_input) = input { + let table = bind_input.bind_table; + let res = FunctionData::SeqTableScanInputData(Box::new(SeqTableScanInputData::new( + table.clone(), + LocalStorage::create_reader(&table.storage), + ))); + Ok(Some(res)) + } else { + Err(FunctionError::InternalError( + "unexpected bind data type".to_string(), + )) + } + } + + fn seq_table_scan_func( + context: Arc, + input: &mut TableFunctionInput, + ) -> Result, FunctionError> { + if let FunctionData::SeqTableScanInputData(data) = &mut input.bind_data { + let batch = data.local_storage_reader.next_batch(context); + Ok(batch) + } else { + Err(FunctionError::InternalError( + "unexpected bind data type".to_string(), + )) + } + } + + pub fn get_function() -> TableFunction { + TableFunction::new( + "seq_table_scan".to_string(), + Some(Self::seq_table_scan_bind_func), + None, + Self::seq_table_scan_func, + ) + } +} diff --git 
a/src/function/table/table_function.rs b/src/function/table/table_function.rs new file mode 100644 index 0000000..dd2f72d --- /dev/null +++ b/src/function/table/table_function.rs @@ -0,0 +1,67 @@ +use std::fmt::Debug; +use std::sync::Arc; + +use arrow::record_batch::RecordBatch; +use derive_new::new; + +use super::{SeqTableScanBindInput, SeqTableScanInitInput}; +use crate::function::{FunctionData, FunctionError}; +use crate::main_entry::ClientContext; + +pub enum GlobalTableFunctionState { + None, +} + +pub enum TableFunctionBindInput { + SeqTableScanBindInput(Box), + None, +} + +#[derive(new)] +pub struct TableFunctionInput { + pub(crate) bind_data: FunctionData, + #[allow(dead_code)] + pub(crate) global_state: GlobalTableFunctionState, +} + +pub enum TableFunctionInitInput { + SeqTableScanInitInput(Box), + None, +} + +pub type TableFunctionBindFunc = + fn(TableFunctionBindInput) -> Result, FunctionError>; + +pub type TableFunc = + fn(Arc, &mut TableFunctionInput) -> Result, FunctionError>; + +pub type TableFunctionInitGlobalFunc = fn( + Arc, + TableFunctionInitInput, +) -> Result; + +#[derive(new, Clone)] +pub struct TableFunction { + /// The name of the function + pub(crate) name: String, + /// Bind function + /// This function is used for determining the return type of a table producing function and + /// returning bind data. The returned FunctionData object should be constant and should not + /// be changed during execution. + pub(crate) bind: Option, + /// (Optional) global init function + /// Initialize the global operator state of the function. + /// The global operator state is used to keep track of the progress in the table function and + /// is shared between all threads working on the table function. 
+ pub(crate) init_global: Option, + /// The main function + pub(crate) function: TableFunc, +} + +impl Debug for TableFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TableFunction") + .field("name", &self.name) + .finish() + } +} diff --git a/src/lib.rs b/src/lib.rs index 7cc7ca3..d35b889 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,9 +11,11 @@ pub mod binder; pub mod catalog; pub mod catalog_v2; pub mod cli; +pub mod common; pub mod db; pub mod execution; pub mod executor; +pub mod function; pub mod main_entry; pub mod optimizer; pub mod parser; diff --git a/src/main.rs b/src/main.rs index c1a6b7a..5f11111 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,6 +16,7 @@ async fn main() -> Result<()> { create_csv_table(&db, "t2")?; let dbv2 = Arc::new(DatabaseInstance::default()); + dbv2.initialize()?; let client_context = ClientContext::new(dbv2); cli::interactive(db, client_context).await?; diff --git a/src/main_entry/db.rs b/src/main_entry/db.rs index daaad42..d53cb71 100644 --- a/src/main_entry/db.rs +++ b/src/main_entry/db.rs @@ -1,20 +1,32 @@ use std::sync::{Arc, RwLock}; -use crate::catalog_v2::{Catalog, DEFAULT_SCHEMA}; +use super::DatabaseError; +use crate::catalog_v2::{Catalog, CatalogError, DEFAULT_SCHEMA}; use crate::storage_v2::LocalStorage; +#[derive(Default)] pub struct DatabaseInstance { pub(crate) storage: RwLock, pub(crate) catalog: Arc>, } -impl Default for DatabaseInstance { - fn default() -> Self { - let mut catalog = Catalog::default(); +impl DatabaseInstance { + pub fn initialize(self: &Arc) -> Result<(), DatabaseError> { + // Create the default schema: main + self.init_default_schema()?; + Ok(()) + } + + fn init_default_schema(self: &Arc) -> Result<(), DatabaseError> { + let mut catalog = match self.catalog.try_write() { + Ok(c) => c, + Err(_) => { + return Err(DatabaseError::CatalogError( + CatalogError::CatalogLockedError, + )) + } + }; 
catalog.create_schema(DEFAULT_SCHEMA.to_string()).unwrap(); - Self { - storage: RwLock::new(LocalStorage::default()), - catalog: Arc::new(RwLock::new(catalog)), - } + Ok(()) } } diff --git a/src/main_entry/errors.rs b/src/main_entry/errors.rs index fa71b4b..9a44e6d 100644 --- a/src/main_entry/errors.rs +++ b/src/main_entry/errors.rs @@ -1,7 +1,9 @@ use arrow::error::ArrowError; use sqlparser::parser::ParserError; +use crate::catalog_v2::CatalogError; use crate::execution::ExecutorError; +use crate::function::FunctionError; use crate::planner_v2::PlannerError; #[derive(thiserror::Error, Debug)] @@ -12,6 +14,12 @@ pub enum DatabaseError { #[from] ParserError, ), + #[error("catalog error: {0}")] + CatalogError( + #[source] + #[from] + CatalogError, + ), #[error("planner error: {0}")] PlannerError( #[source] @@ -30,6 +38,12 @@ pub enum DatabaseError { #[from] ArrowError, ), + #[error("Function error: {0}")] + FunctionError( + #[source] + #[from] + FunctionError, + ), #[error("Internal error: {0}")] InternalError(String), } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 814566d..b8cafd4 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1,4 +1,4 @@ -use sqlparser::ast::Statement; +use sqlparser::ast::{Query, Statement}; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::{Parser, ParserError}; @@ -19,4 +19,20 @@ impl Sqlparser { let stmts = Parser::parse_sql(&dialect, sql.as_str())?; Ok(stmts) } + + pub fn parse_one_query(sql: String) -> Result, ParserError> { + let dialect = PostgreSqlDialect {}; + let stmts = Parser::parse_sql(&dialect, sql.as_str())?; + if stmts.len() != 1 { + return Err(ParserError::ParserError( + "not a single statement".to_string(), + )); + } + match stmts[0].clone() { + Statement::Query(q) => Ok(q), + _ => Err(ParserError::ParserError( + "only expect query statement".to_string(), + )), + } + } } diff --git a/src/planner_v2/binder/mod.rs b/src/planner_v2/binder/mod.rs index cdd1f99..2722cf3 100644 --- 
a/src/planner_v2/binder/mod.rs +++ b/src/planner_v2/binder/mod.rs @@ -3,9 +3,9 @@ mod binding; mod errors; mod expression; mod query_node; +mod sqlparser_util; mod statement; mod tableref; -mod util; use std::sync::Arc; @@ -14,9 +14,9 @@ pub use binding::*; pub use errors::*; pub use expression::*; pub use query_node::*; +pub use sqlparser_util::*; pub use statement::*; pub use tableref::*; -pub use util::*; use crate::main_entry::ClientContext; diff --git a/src/planner_v2/binder/sqlparser_util.rs b/src/planner_v2/binder/sqlparser_util.rs new file mode 100644 index 0000000..837385f --- /dev/null +++ b/src/planner_v2/binder/sqlparser_util.rs @@ -0,0 +1,112 @@ +use sqlparser::ast::{ + ColumnDef, Ident, ObjectName, Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins, + WildcardAdditionalOptions, +}; + +use super::BindError; +use crate::catalog_v2::{ColumnDefinition, DEFAULT_SCHEMA}; + +pub struct SqlparserResolver; + +impl SqlparserResolver { + /// Resolve object_name which is a name of a table, view, custom type, etc., possibly + /// multi-part, i.e. 
db.schema.obj + pub fn object_name_to_schema_table( + object_name: &ObjectName, + ) -> Result<(String, String), BindError> { + let (schema, table) = match object_name.0.as_slice() { + [table] => (DEFAULT_SCHEMA.to_string(), table.value.clone()), + [schema, table] => (schema.value.clone(), table.value.clone()), + _ => return Err(BindError::SqlParserUnsupportedStmt(object_name.to_string())), + }; + Ok((schema, table)) + } + + pub fn column_def_to_column_definition( + column_def: &ColumnDef, + ) -> Result { + let name = column_def.name.value.clone(); + let ty = column_def.data_type.clone().try_into()?; + Ok(ColumnDefinition::new(name, ty)) + } +} + +#[derive(Default)] +pub struct SqlparserSelectBuilder { + projection: Vec, + from: Vec, +} + +impl SqlparserSelectBuilder { + pub fn projection(mut self, projection: Vec) -> Self { + self.projection = projection; + self + } + + pub fn projection_wildcard(mut self) -> Self { + self.projection = vec![SelectItem::Wildcard(WildcardAdditionalOptions::default())]; + self + } + + pub fn from(mut self, from: Vec) -> Self { + self.from = from; + self + } + + pub fn from_table(mut self, table_name: String) -> Self { + let relation = TableFactor::Table { + name: ObjectName(vec![Ident::new(table_name)]), + alias: None, + args: None, + with_hints: vec![], + }; + let table = TableWithJoins { + relation, + joins: vec![], + }; + self.from = vec![table]; + self + } + + pub fn build(self) -> sqlparser::ast::Select { + sqlparser::ast::Select { + distinct: false, + top: None, + projection: self.projection, + into: None, + from: self.from, + lateral_views: vec![], + selection: None, + group_by: vec![], + cluster_by: vec![], + distribute_by: vec![], + sort_by: vec![], + having: None, + qualify: None, + } + } +} + +pub struct SqlparserQueryBuilder { + body: Box, +} + +impl SqlparserQueryBuilder { + pub fn new_from_select(select: Select) -> Self { + Self { + body: Box::new(SetExpr::Select(Box::new(select))), + } + } + + pub fn build(self) -> 
Query { + Query { + with: None, + body: self.body, + order_by: vec![], + limit: None, + offset: None, + fetch: None, + lock: None, + } + } +} diff --git a/src/planner_v2/binder/statement/bind_create.rs b/src/planner_v2/binder/statement/bind_create.rs index d066bc8..4f2258c 100644 --- a/src/planner_v2/binder/statement/bind_create.rs +++ b/src/planner_v2/binder/statement/bind_create.rs @@ -1,9 +1,10 @@ use sqlparser::ast::Statement; -use super::{BoundStatement, CreateTableInfo}; +use super::BoundStatement; use crate::catalog_v2::ColumnDefinition; +use crate::common::{CreateInfoBase, CreateTableInfo}; use crate::planner_v2::{ - BindError, Binder, CreateInfoBase, LogicalCreateTable, LogicalOperator, SqlparserResolver, + BindError, Binder, LogicalCreateTable, LogicalOperator, SqlparserResolver, }; use crate::types_v2::LogicalType; diff --git a/src/planner_v2/binder/statement/mod.rs b/src/planner_v2/binder/statement/mod.rs index 25a0905..039e365 100644 --- a/src/planner_v2/binder/statement/mod.rs +++ b/src/planner_v2/binder/statement/mod.rs @@ -2,11 +2,9 @@ mod bind_create; mod bind_explain; mod bind_insert; mod bind_select; -mod create_info; pub use bind_create::*; pub use bind_insert::*; pub use bind_select::*; -pub use create_info::*; use derive_new::new; use sqlparser::ast::Statement; diff --git a/src/planner_v2/binder/tableref/bind_base_table_ref.rs b/src/planner_v2/binder/tableref/bind_base_table_ref.rs index 40882c9..b169cac 100644 --- a/src/planner_v2/binder/tableref/bind_base_table_ref.rs +++ b/src/planner_v2/binder/tableref/bind_base_table_ref.rs @@ -2,6 +2,7 @@ use derive_new::new; use super::BoundTableRef; use crate::catalog_v2::{Catalog, CatalogEntry, TableCatalogEntry}; +use crate::function::{FunctionData, SeqTableScan, SeqTableScanBindInput, TableFunctionBindInput}; use crate::planner_v2::{ BindError, Binder, LogicalGet, LogicalOperator, LogicalOperatorBase, SqlparserResolver, }; @@ -20,13 +21,24 @@ impl Binder { table: sqlparser::ast::TableFactor, ) -> 
Result { match table { - sqlparser::ast::TableFactor::Table { name, alias, .. } => { + sqlparser::ast::TableFactor::Table { + name, alias, args, .. + } => { let table_index = self.generate_table_index(); let (schema, table) = SqlparserResolver::object_name_to_schema_table(&name)?; let alias = alias .map(|a| a.to_string()) .unwrap_or_else(|| table.clone()); - let table = Catalog::get_table(self.clone_client_context(), schema, table)?; + + if args.is_some() { + todo!("bind table function"); + } + + let table_res = Catalog::get_table(self.clone_client_context(), schema, table); + if table_res.is_err() { + todo!("table could not be found: try to bind a replacement scan"); + } + let table = table_res.unwrap(); let mut return_names = vec![]; let mut return_types = vec![]; @@ -34,10 +46,22 @@ impl Binder { return_names.push(col.name.clone()); return_types.push(col.ty.clone()); } + + let mut bind_data = FunctionData::None; + let seq_table_scan_func = SeqTableScan::get_function(); + if let Some(bind_func) = &seq_table_scan_func.bind { + bind_data = bind_func(TableFunctionBindInput::SeqTableScanBindInput(Box::new( + SeqTableScanBindInput::new(table.clone()), + ))) + .unwrap() + .unwrap(); + } + let logical_get = LogicalGet::new( LogicalOperatorBase::default(), table_index, - table.clone(), + seq_table_scan_func, + bind_data, return_types.clone(), return_names.clone(), ); diff --git a/src/planner_v2/binder/util.rs b/src/planner_v2/binder/util.rs deleted file mode 100644 index 5e59223..0000000 --- a/src/planner_v2/binder/util.rs +++ /dev/null @@ -1,29 +0,0 @@ -use sqlparser::ast::{ColumnDef, ObjectName}; - -use super::BindError; -use crate::catalog_v2::{ColumnDefinition, DEFAULT_SCHEMA}; - -pub struct SqlparserResolver {} - -impl SqlparserResolver { - /// Resolve object_name which is a name of a table, view, custom type, etc., possibly - /// multi-part, i.e. 
db.schema.obj - pub fn object_name_to_schema_table( - object_name: &ObjectName, - ) -> Result<(String, String), BindError> { - let (schema, table) = match object_name.0.as_slice() { - [table] => (DEFAULT_SCHEMA.to_string(), table.value.clone()), - [schema, table] => (schema.value.clone(), table.value.clone()), - _ => return Err(BindError::SqlParserUnsupportedStmt(object_name.to_string())), - }; - Ok((schema, table)) - } - - pub fn column_def_to_column_definition( - column_def: &ColumnDef, - ) -> Result { - let name = column_def.name.value.clone(); - let ty = column_def.data_type.clone().try_into()?; - Ok(ColumnDefinition::new(name, ty)) - } -} diff --git a/src/planner_v2/operator/logical_get.rs b/src/planner_v2/operator/logical_get.rs index ede1175..92683e1 100644 --- a/src/planner_v2/operator/logical_get.rs +++ b/src/planner_v2/operator/logical_get.rs @@ -1,7 +1,7 @@ use derive_new::new; use super::LogicalOperatorBase; -use crate::catalog_v2::TableCatalogEntry; +use crate::function::{FunctionData, TableFunction}; use crate::types_v2::LogicalType; /// LogicalGet represents a scan operation from a data source @@ -9,8 +9,10 @@ use crate::types_v2::LogicalType; pub struct LogicalGet { pub(crate) base: LogicalOperatorBase, pub(crate) table_idx: usize, - // TODO: migrate to FunctionData when support TableFunction - pub(crate) bind_table: TableCatalogEntry, + /// The function that is called + pub(crate) function: TableFunction, + // The bind data of the function + pub(crate) bind_data: FunctionData, /// The types of ALL columns that can be returned by the table function pub(crate) returned_types: Vec, /// The names of ALL columns that can be returned by the table function diff --git a/src/storage_v2/local_storage.rs b/src/storage_v2/local_storage.rs index 26c8d30..bad5ab6 100644 --- a/src/storage_v2/local_storage.rs +++ b/src/storage_v2/local_storage.rs @@ -38,7 +38,7 @@ impl LocalStorage { } } -#[derive(new)] +#[derive(new, Debug, Clone)] pub struct LocalStorageReader { 
table: DataTable, #[new(default)] diff --git a/src/util/tree_render.rs b/src/util/tree_render.rs index 2d1895a..496c8b9 100644 --- a/src/util/tree_render.rs +++ b/src/util/tree_render.rs @@ -4,6 +4,7 @@ use derive_new::new; use crate::catalog_v2::ColumnDefinition; use crate::execution::PhysicalOperator; +use crate::function::FunctionData; use crate::planner_v2::{BoundExpression, LogicalOperator}; #[derive(new)] @@ -62,10 +63,17 @@ impl TreeRender { ) } LogicalOperator::LogicalGet(op) => { - format!( - "LogicalGet: {}.{}", - op.bind_table.storage.info.schema, op.bind_table.storage.info.table - ) + let get_table_str = match &op.bind_data { + FunctionData::SeqTableScanInputData(input) => { + format!( + "{}.{}", + input.bind_table.storage.info.schema, + input.bind_table.storage.info.table + ) + } + FunctionData::None => "None".to_string(), + }; + format!("LogicalGet: {}", get_table_str) } LogicalOperator::LogicalProjection(op) => { let exprs = op diff --git a/tests/sqllogictest/src/lib.rs b/tests/sqllogictest/src/lib.rs index efba3de..2b66725 100644 --- a/tests/sqllogictest/src/lib.rs +++ b/tests/sqllogictest/src/lib.rs @@ -65,6 +65,7 @@ impl AsyncDB for DatabaseWrapperV2 { pub fn test_run_v2(sqlfile: &str) { let dbv2 = Arc::new(DatabaseInstance::default()); + dbv2.initialize().unwrap(); let client_context = ClientContext::new(dbv2); let mut tester = Runner::new(DatabaseWrapperV2 { client_context }); tester.run_file(sqlfile).unwrap()