From 5ae40531be1c246382360e93f64631467c489d9b Mon Sep 17 00:00:00 2001 From: Fedomn Date: Tue, 20 Dec 2022 21:04:05 +0800 Subject: [PATCH] feat(planner): support show tables pragma Signed-off-by: Fedomn --- src/catalog_v2/catalog.rs | 6 +- src/catalog_v2/entry/schema_catalog_entry.rs | 10 ++ .../physical_plan/physical_table_scan.rs | 2 +- src/execution/volcano_executor/table_scan.rs | 20 +-- src/function/mod.rs | 30 ++++- src/function/table/mod.rs | 2 + src/function/table/seq_table_scan.rs | 43 +++--- src/function/table/sqlrs_tables.rs | 125 ++++++++++++++++++ src/function/table/table_function.rs | 45 +++---- src/main_entry/db.rs | 12 +- src/planner_v2/binder/bind_context.rs | 11 ++ src/planner_v2/binder/errors.rs | 11 +- .../binder/query_node/plan_select_node.rs | 3 +- src/planner_v2/binder/sqlparser_util.rs | 27 +++- .../binder/statement/bind_show_tables.rs | 21 +++ src/planner_v2/binder/statement/mod.rs | 4 + .../binder/tableref/bind_base_table_ref.rs | 29 ++-- .../binder/tableref/bind_table_function.rs | 79 +++++++++++ src/planner_v2/binder/tableref/mod.rs | 6 + .../binder/tableref/plan_table_function.rs | 11 ++ src/planner_v2/operator/logical_get.rs | 2 +- src/types_v2/values.rs | 120 ++++++++++++++++- src/util/tree_render.rs | 20 +-- tests/slt/pragma.slt | 9 ++ tests/slt/table_function.slt | 9 ++ 25 files changed, 557 insertions(+), 100 deletions(-) create mode 100644 src/function/table/sqlrs_tables.rs create mode 100644 src/planner_v2/binder/statement/bind_show_tables.rs create mode 100644 src/planner_v2/binder/tableref/bind_table_function.rs create mode 100644 src/planner_v2/binder/tableref/plan_table_function.rs create mode 100644 tests/slt/pragma.slt create mode 100644 tests/slt/table_function.slt diff --git a/src/catalog_v2/catalog.rs b/src/catalog_v2/catalog.rs index f71c857..440014e 100644 --- a/src/catalog_v2/catalog.rs +++ b/src/catalog_v2/catalog.rs @@ -83,6 +83,7 @@ impl Catalog { pub fn scan_entries( client_context: Arc, + schema: String, callback: &F, ) -> Result, CatalogError> where @@ -92,7 +93,10 @@ impl Catalog { Ok(c) => c, Err(_) => return Err(CatalogError::CatalogLockedError), }; - Ok(catalog.schemas.scan_entries(callback)) + if let CatalogEntry::SchemaCatalogEntry(entry) = catalog.schemas.get_entry(schema)? { + return Ok(entry.scan_entries(callback)); + } + Err(CatalogError::CatalogEntryTypeNotMatch) } pub fn get_table_function( diff --git a/src/catalog_v2/entry/schema_catalog_entry.rs b/src/catalog_v2/entry/schema_catalog_entry.rs index 0dac53d..276fd53 100644 --- a/src/catalog_v2/entry/schema_catalog_entry.rs +++ b/src/catalog_v2/entry/schema_catalog_entry.rs @@ -66,4 +66,14 @@ impl SchemaCatalogEntry { _ => Err(CatalogError::CatalogEntryNotExists(table_function)), } } + + pub fn scan_entries(&self, callback: &F) -> Vec + where + F: Fn(&CatalogEntry) -> bool, + { + let mut result = vec![]; + result.extend(self.tables.scan_entries(callback)); + result.extend(self.functions.scan_entries(callback)); + result + } } diff --git a/src/execution/physical_plan/physical_table_scan.rs b/src/execution/physical_plan/physical_table_scan.rs index 325a294..4416ce6 100644 --- a/src/execution/physical_plan/physical_table_scan.rs +++ b/src/execution/physical_plan/physical_table_scan.rs @@ -10,7 +10,7 @@ use crate::types_v2::LogicalType; pub struct PhysicalTableScan { pub(crate) base: PhysicalOperatorBase, pub(crate) function: TableFunction, - pub(crate) bind_data: FunctionData, + pub(crate) bind_data: Option, /// The types of ALL columns that can be returned by the table function pub(crate) returned_types: Vec, /// The names of ALL columns that can be returned by the table function diff --git a/src/execution/volcano_executor/table_scan.rs b/src/execution/volcano_executor/table_scan.rs index 261a74f..1387605 100644 --- a/src/execution/volcano_executor/table_scan.rs +++ b/src/execution/volcano_executor/table_scan.rs @@ -5,9 +5,7 @@ use derive_new::new; use futures_async_stream::try_stream; use crate::execution::{ExecutionContext, ExecutorError, PhysicalTableScan, SchemaUtil}; -use crate::function::{ - GlobalTableFunctionState, SeqTableScanInitInput, TableFunctionInitInput, TableFunctionInput, -}; +use crate::function::TableFunctionInput; #[derive(new)] pub struct TableScan { @@ -21,22 +19,16 @@ impl TableScan { let bind_data = self.plan.bind_data; - let table_scan_func = self.plan.function.function; - let global_state = if let Some(init_global_func) = self.plan.function.init_global { - let seq_table_scan_init_input = TableFunctionInitInput::SeqTableScanInitInput( - Box::new(SeqTableScanInitInput::new(bind_data.clone())), - ); - init_global_func(context.clone_client_context(), seq_table_scan_init_input)? - } else { - GlobalTableFunctionState::None - }; + let function = self.plan.function; + let table_scan_func = function.function; + let mut tabel_scan_input = TableFunctionInput::new(bind_data); - let mut tabel_scan_input = TableFunctionInput::new(bind_data, global_state); while let Some(batch) = table_scan_func(context.clone_client_context(), &mut tabel_scan_input)? { let columns = batch.columns().to_vec(); - yield RecordBatch::try_new(schema.clone(), columns)? + let try_new = RecordBatch::try_new(schema.clone(), columns)?; + yield try_new } } } diff --git a/src/function/mod.rs b/src/function/mod.rs index b0e4d4a..ccb6afe 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -1,11 +1,39 @@ mod errors; mod table; +use std::sync::Arc; + +use derive_new::new; pub use errors::*; pub use table::*; +use crate::catalog_v2::{Catalog, DEFAULT_SCHEMA}; +use crate::common::{CreateInfoBase, CreateTableFunctionInfo}; +use crate::main_entry::ClientContext; + #[derive(Debug, Clone)] pub enum FunctionData { SeqTableScanInputData(Box), - None, + SqlrsTablesData(Box), + Placeholder, +} + +#[derive(new)] +pub struct BuiltinFunctions { + pub(crate) context: Arc, +} + +impl BuiltinFunctions { + pub fn add_table_functions(&mut self, function: TableFunction) -> Result<(), FunctionError> { + let info = CreateTableFunctionInfo::new( + CreateInfoBase::new(DEFAULT_SCHEMA.to_string()), + function.name.clone(), + vec![function], + ); + Ok(Catalog::create_table_function(self.context.clone(), info)?) + } + + pub fn initialize(&mut self) -> Result<(), FunctionError> { + SqlrsTablesFunc::register_function(self) + } } diff --git a/src/function/table/mod.rs b/src/function/table/mod.rs index c61bd1b..5a4798c 100644 --- a/src/function/table/mod.rs +++ b/src/function/table/mod.rs @@ -1,4 +1,6 @@ mod seq_table_scan; +mod sqlrs_tables; mod table_function; pub use seq_table_scan::*; +pub use sqlrs_tables::*; pub use table_function::*; diff --git a/src/function/table/seq_table_scan.rs b/src/function/table/seq_table_scan.rs index 22eee15..b04ad7f 100644 --- a/src/function/table/seq_table_scan.rs +++ b/src/function/table/seq_table_scan.rs @@ -8,6 +8,7 @@ use crate::catalog_v2::TableCatalogEntry; use crate::function::{FunctionData, FunctionError}; use crate::main_entry::ClientContext; use crate::storage_v2::{LocalStorage, LocalStorageReader}; +use crate::types_v2::LogicalType; /// The table scan function represents a sequential scan over one of base tables. pub struct SeqTableScan; @@ -18,23 +19,15 @@ pub struct SeqTableScanInputData { pub(crate) local_storage_reader: LocalStorageReader, } -#[derive(new)] -pub struct SeqTableScanBindInput { - pub(crate) bind_table: TableCatalogEntry, -} - -#[derive(new)] -pub struct SeqTableScanInitInput { - #[allow(dead_code)] - pub(crate) bind_data: FunctionData, -} - impl SeqTableScan { - fn seq_table_scan_bind_func( + #[allow(clippy::ptr_arg)] + fn bind_func( + _context: Arc, input: TableFunctionBindInput, + _return_types: &mut Vec, + _return_names: &mut Vec, ) -> Result, FunctionError> { - if let TableFunctionBindInput::SeqTableScanBindInput(bind_input) = input { - let table = bind_input.bind_table; + if let Some(table) = input.bind_table { let res = FunctionData::SeqTableScanInputData(Box::new(SeqTableScanInputData::new( table.clone(), LocalStorage::create_reader(&table.storage), @@ -47,26 +40,28 @@ impl SeqTableScan { } } - fn seq_table_scan_func( + fn scan_func( context: Arc, input: &mut TableFunctionInput, ) -> Result, FunctionError> { - if let FunctionData::SeqTableScanInputData(data) = &mut input.bind_data { - let batch = data.local_storage_reader.next_batch(context); - Ok(batch) + if let Some(bind_data) = &mut input.bind_data { + if let FunctionData::SeqTableScanInputData(data) = bind_data { + Ok(data.local_storage_reader.next_batch(context)) + } else { + Err(FunctionError::InternalError( + "unexpected bind data type".to_string(), + )) + } } else { - Err(FunctionError::InternalError( - "unexpected bind data type".to_string(), - )) + Ok(None) } } pub fn get_function() -> TableFunction { TableFunction::new( "seq_table_scan".to_string(), - Some(Self::seq_table_scan_bind_func), - None, - Self::seq_table_scan_func, + Some(Self::bind_func), + Self::scan_func, ) } } diff --git a/src/function/table/sqlrs_tables.rs b/src/function/table/sqlrs_tables.rs new file mode 100644 index 0000000..3fc506f --- /dev/null +++ b/src/function/table/sqlrs_tables.rs @@ -0,0 +1,125 @@ +use std::sync::Arc; + +use arrow::record_batch::RecordBatch; +use derive_new::new; + +use super::{TableFunction, TableFunctionBindInput, TableFunctionInput}; +use crate::catalog_v2::{Catalog, CatalogEntry, DEFAULT_SCHEMA}; +use crate::execution::SchemaUtil; +use crate::function::{BuiltinFunctions, FunctionData, FunctionError}; +use crate::main_entry::ClientContext; +use crate::types_v2::{LogicalType, ScalarValue}; + +pub struct SqlrsTablesFunc; + +#[derive(new, Debug, Clone)] +pub struct SqlrsTablesData { + pub(crate) entries: Vec, + pub(crate) return_types: Vec, + pub(crate) return_names: Vec, + pub(crate) current_cursor: usize, +} + +impl SqlrsTablesFunc { + fn generate_sqlrs_tables_names() -> Vec { + vec![ + "schema_name".to_string(), + "schema_oid".to_string(), + "table_name".to_string(), + "table_oid".to_string(), + ] + } + + fn generate_sqlrs_tables_types() -> Vec { + vec![ + LogicalType::Varchar, + LogicalType::Integer, + LogicalType::Varchar, + LogicalType::Integer, + ] + } + + fn bind_func( + context: Arc, + _input: TableFunctionBindInput, + return_types: &mut Vec, + return_names: &mut Vec, + ) -> Result, FunctionError> { + let entries = Catalog::scan_entries(context, DEFAULT_SCHEMA.to_string(), &|entry| { + matches!(entry, CatalogEntry::TableCatalogEntry(_)) + })?; + let data = SqlrsTablesData::new( + entries, + Self::generate_sqlrs_tables_types(), + Self::generate_sqlrs_tables_names(), + 0, + ); + return_types.extend(data.return_types.clone()); + return_names.extend(data.return_names.clone()); + Ok(Some(FunctionData::SqlrsTablesData(Box::new(data)))) + } + + fn tables_func( + _context: Arc, + input: &mut TableFunctionInput, + ) -> Result, FunctionError> { + if input.bind_data.is_none() { + return Ok(None); + } + + let bind_data = input.bind_data.as_mut().unwrap(); + if let FunctionData::SqlrsTablesData(data) = bind_data { + if data.current_cursor >= data.entries.len() { + return Ok(None); + } + + let schema = SchemaUtil::new_schema_ref(&data.return_names, &data.return_types); + let mut schema_names = ScalarValue::new_builder(&LogicalType::Varchar)?; + let mut schema_oids = ScalarValue::new_builder(&LogicalType::Integer)?; + let mut table_names = ScalarValue::new_builder(&LogicalType::Varchar)?; + let mut table_oids = ScalarValue::new_builder(&LogicalType::Integer)?; + for entry in data.entries.iter() { + if let CatalogEntry::TableCatalogEntry(table) = entry { + ScalarValue::append_for_builder( + &ScalarValue::Utf8(Some(table.schema_base.name.clone())), + &mut schema_names, + )?; + ScalarValue::append_for_builder( + &ScalarValue::Int32(Some(table.schema_base.oid as i32)), + &mut schema_oids, + )?; + ScalarValue::append_for_builder( + &ScalarValue::Utf8(Some(table.base.name.clone())), + &mut table_names, + )?; + ScalarValue::append_for_builder( + &ScalarValue::Int32(Some(table.base.oid as i32)), + &mut table_oids, + )?; + } + } + let cols = vec![ + schema_names.finish(), + schema_oids.finish(), + table_names.finish(), + table_oids.finish(), + ]; + data.current_cursor += data.entries.len(); + let batch = RecordBatch::try_new(schema, cols)?; + Ok(Some(batch)) + } else { + Err(FunctionError::InternalError( + "unexpected global state type".to_string(), + )) + } + } + + pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> { + set.add_table_functions(TableFunction::new( + "sqlrs_tables".to_string(), + Some(Self::bind_func), + Self::tables_func, + ))?; + Ok(()) + } +} diff --git a/src/function/table/table_function.rs b/src/function/table/table_function.rs index dd2f72d..c85e6b5 100644 --- a/src/function/table/table_function.rs +++ b/src/function/table/table_function.rs @@ -3,43 +3,35 @@ use std::sync::Arc; use arrow::record_batch::RecordBatch; use derive_new::new; +use sqlparser::ast::FunctionArg; -use super::{SeqTableScanBindInput, SeqTableScanInitInput}; +use crate::catalog_v2::TableCatalogEntry; use crate::function::{FunctionData, FunctionError}; use crate::main_entry::ClientContext; +use crate::types_v2::LogicalType; -pub enum GlobalTableFunctionState { - None, -} - -pub enum TableFunctionBindInput { - SeqTableScanBindInput(Box), - None, -} - -#[derive(new)] -pub struct TableFunctionInput { - pub(crate) bind_data: FunctionData, +#[derive(new, Default)] +pub struct TableFunctionBindInput { + pub(crate) bind_table: Option, #[allow(dead_code)] - pub(crate) global_state: GlobalTableFunctionState, + pub(crate) func_args: Option>, } -pub enum TableFunctionInitInput { - SeqTableScanInitInput(Box), - None, +#[derive(new, Default)] +pub struct TableFunctionInput { + pub(crate) bind_data: Option, } -pub type TableFunctionBindFunc = - fn(TableFunctionBindInput) -> Result, FunctionError>; +pub type TableFunctionBindFunc = fn( + Arc, + TableFunctionBindInput, + &mut Vec, + &mut Vec, +) -> Result, FunctionError>; pub type TableFunc = fn(Arc, &mut TableFunctionInput) -> Result, FunctionError>; -pub type TableFunctionInitGlobalFunc = fn( - Arc, - TableFunctionInitInput, -) -> Result; - #[derive(new, Clone)] pub struct TableFunction { // The name of the function @@ -49,11 +41,6 @@ pub struct TableFunction { /// returning bind data The returned FunctionData object should be constant and should not /// be changed during execution. pub(crate) bind: Option, - /// (Optional) global init function - /// Initialize the global operator state of the function. - /// The global operator state is used to keep track of the progress in the table function and - /// is shared between all threads working on the table function. - pub(crate) init_global: Option, /// The main function pub(crate) function: TableFunc, } diff --git a/src/main_entry/db.rs b/src/main_entry/db.rs index d53cb71..5a7ff0e 100644 --- a/src/main_entry/db.rs +++ b/src/main_entry/db.rs @@ -1,7 +1,8 @@ use std::sync::{Arc, RwLock}; -use super::DatabaseError; +use super::{ClientContext, DatabaseError}; use crate::catalog_v2::{Catalog, CatalogError, DEFAULT_SCHEMA}; +use crate::function::BuiltinFunctions; use crate::storage_v2::LocalStorage; #[derive(Default)] @@ -14,6 +15,8 @@ impl DatabaseInstance { pub fn initialize(self: &Arc) -> Result<(), DatabaseError> { // Create the default schema: main self.init_default_schema()?; + // Initialize the builtin functions + self.init_builtin_functions()?; Ok(()) } @@ -29,4 +32,11 @@ impl DatabaseInstance { catalog.create_schema(DEFAULT_SCHEMA.to_string()).unwrap(); Ok(()) } + + fn init_builtin_functions(self: &Arc) -> Result<(), DatabaseError> { + let context = ClientContext::new(self.clone()); + let mut buildin_funcs = BuiltinFunctions::new(context); + buildin_funcs.initialize()?; + Ok(()) + } } diff --git a/src/planner_v2/binder/bind_context.rs b/src/planner_v2/binder/bind_context.rs index 8e5e3df..7f7e7b4 100644 --- a/src/planner_v2/binder/bind_context.rs +++ b/src/planner_v2/binder/bind_context.rs @@ -58,6 +58,17 @@ impl BindContext { self.add_binding(alias, index, types, names, Some(catalog_entry)); } + pub fn add_table_function( + &mut self, + alias: String, + index: usize, + types: Vec, + names: Vec, + catalog_entry: CatalogEntry, + ) { + self.add_binding(alias, index, types, names, Some(catalog_entry)); + } + pub fn get_binding(&self, table_name: &str) -> Option { self.bindings.get(table_name).cloned() } diff --git a/src/planner_v2/binder/errors.rs b/src/planner_v2/binder/errors.rs index 83075bc..8cf6922 100644 --- a/src/planner_v2/binder/errors.rs +++ b/src/planner_v2/binder/errors.rs @@ -1,3 +1,6 @@ +use crate::catalog_v2::CatalogError; +use crate::function::FunctionError; + #[derive(thiserror::Error, Debug)] pub enum BindError { #[error("unsupported expr: {0}")] @@ -18,6 +21,12 @@ pub enum BindError { CatalogError( #[from] #[source] - crate::catalog_v2::CatalogError, + CatalogError, + ), + #[error("function error: {0}")] + FunctionError( + #[from] + #[source] + FunctionError, ), } diff --git a/src/planner_v2/binder/query_node/plan_select_node.rs b/src/planner_v2/binder/query_node/plan_select_node.rs index f2f1738..8f23324 100644 --- a/src/planner_v2/binder/query_node/plan_select_node.rs +++ b/src/planner_v2/binder/query_node/plan_select_node.rs @@ -1,6 +1,6 @@ use super::BoundSelectNode; use crate::planner_v2::BoundTableRef::{ - BoundBaseTableRef, BoundDummyTableRef, BoundExpressionListRef, + BoundBaseTableRef, BoundDummyTableRef, BoundExpressionListRef, BoundTableFunction, }; use crate::planner_v2::{ BindError, Binder, BoundCastExpression, BoundStatement, LogicalOperator, LogicalOperatorBase, @@ -19,6 +19,7 @@ impl Binder { } BoundBaseTableRef(bound_ref) => self.create_plan_for_base_tabel_ref(*bound_ref)?, BoundDummyTableRef(bound_ref) => self.create_plan_for_dummy_table_ref(bound_ref)?, + BoundTableFunction(bound_func) => self.create_plan_for_table_function(*bound_func)?, }; let root = LogicalOperator::LogicalProjection(LogicalProjection::new( diff --git a/src/planner_v2/binder/sqlparser_util.rs b/src/planner_v2/binder/sqlparser_util.rs index 837385f..b2fe757 100644 --- a/src/planner_v2/binder/sqlparser_util.rs +++ b/src/planner_v2/binder/sqlparser_util.rs @@ -1,6 +1,6 @@ use sqlparser::ast::{ - ColumnDef, Ident, ObjectName, Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins, - WildcardAdditionalOptions, + ColumnDef, Expr, Ident, ObjectName, Query, Select, SelectItem, SetExpr, TableFactor, + TableWithJoins, WildcardAdditionalOptions, }; use super::BindError; @@ -43,6 +43,14 @@ impl SqlparserSelectBuilder { self } + pub fn projection_cols(mut self, cols: Vec<&str>) -> Self { + self.projection = cols + .into_iter() + .map(|col| SelectItem::UnnamedExpr(Expr::Identifier(Ident::new(col)))) + .collect(); + self + } + pub fn projection_wildcard(mut self) -> Self { self.projection = vec![SelectItem::Wildcard(WildcardAdditionalOptions::default())]; self @@ -68,6 +76,21 @@ impl SqlparserSelectBuilder { self } + pub fn from_table_function(mut self, table_function_name: &str) -> Self { + let relation = TableFactor::Table { + name: ObjectName(vec![Ident::new(table_function_name)]), + alias: None, + args: Some(vec![]), + with_hints: vec![], + }; + let table = TableWithJoins { + relation, + joins: vec![], + }; + self.from = vec![table]; + self + } + pub fn build(self) -> sqlparser::ast::Select { sqlparser::ast::Select { distinct: false, diff --git a/src/planner_v2/binder/statement/bind_show_tables.rs b/src/planner_v2/binder/statement/bind_show_tables.rs new file mode 100644 index 0000000..22fcdd2 --- /dev/null +++ b/src/planner_v2/binder/statement/bind_show_tables.rs @@ -0,0 +1,21 @@ +use sqlparser::ast::Statement; + +use super::BoundStatement; +use crate::planner_v2::{BindError, Binder, SqlparserQueryBuilder, SqlparserSelectBuilder}; + +impl Binder { + pub fn bind_show_tables(&mut self, stmt: &Statement) -> Result { + match stmt { + Statement::ShowTables { .. } => { + let select = SqlparserSelectBuilder::default() + .projection_cols(vec!["schema_name", "table_name"]) + .from_table_function("sqlrs_tables") + .build(); + let query = SqlparserQueryBuilder::new_from_select(select).build(); + let node = self.bind_select_node(&query)?; + self.create_plan_for_select_node(node) + } + _ => Err(BindError::UnsupportedStmt(format!("{:?}", stmt))), + } + } +} diff --git a/src/planner_v2/binder/statement/mod.rs b/src/planner_v2/binder/statement/mod.rs index 039e365..8be8e37 100644 --- a/src/planner_v2/binder/statement/mod.rs +++ b/src/planner_v2/binder/statement/mod.rs @@ -2,9 +2,12 @@ mod bind_create; mod bind_explain; mod bind_insert; mod bind_select; +mod bind_show_tables; + pub use bind_create::*; pub use bind_insert::*; pub use bind_select::*; +pub use bind_show_tables::*; use derive_new::new; use sqlparser::ast::Statement; @@ -26,6 +29,7 @@ impl Binder { Statement::Insert { .. } => self.bind_insert(statement), Statement::Query { .. } => self.bind_select(statement), Statement::Explain { .. } => self.bind_explain(statement), + Statement::ShowTables { .. } => self.bind_show_tables(statement), _ => Err(BindError::UnsupportedStmt(format!("{:?}", statement))), } } diff --git a/src/planner_v2/binder/tableref/bind_base_table_ref.rs b/src/planner_v2/binder/tableref/bind_base_table_ref.rs index b169cac..58b4855 100644 --- a/src/planner_v2/binder/tableref/bind_base_table_ref.rs +++ b/src/planner_v2/binder/tableref/bind_base_table_ref.rs @@ -2,7 +2,7 @@ use derive_new::new; use super::BoundTableRef; use crate::catalog_v2::{Catalog, CatalogEntry, TableCatalogEntry}; -use crate::function::{FunctionData, SeqTableScan, SeqTableScanBindInput, TableFunctionBindInput}; +use crate::function::{SeqTableScan, TableFunctionBindInput}; use crate::planner_v2::{ BindError, Binder, LogicalGet, LogicalOperator, LogicalOperatorBase, SqlparserResolver, }; @@ -20,23 +20,22 @@ impl Binder { &mut self, table: sqlparser::ast::TableFactor, ) -> Result { - match table { + match table.clone() { sqlparser::ast::TableFactor::Table { name, alias, args, .. } => { - let table_index = self.generate_table_index(); + if args.is_some() { + return self.bind_table_function(table); + } let (schema, table) = SqlparserResolver::object_name_to_schema_table(&name)?; let alias = alias .map(|a| a.to_string()) .unwrap_or_else(|| table.clone()); - if args.is_some() { - todo!("bind table function"); - } - let table_res = Catalog::get_table(self.clone_client_context(), schema, table); if table_res.is_err() { - todo!("table could not be found: try to bind a replacement scan"); + // table could not be found: try to bind a replacement scan + return Err(BindError::CatalogError(table_res.err().unwrap())); } let table = table_res.unwrap(); @@ -47,16 +46,18 @@ impl Binder { return_types.push(col.ty.clone()); } - let mut bind_data = FunctionData::None; + let mut bind_data = None; let seq_table_scan_func = SeqTableScan::get_function(); if let Some(bind_func) = &seq_table_scan_func.bind { - bind_data = bind_func(TableFunctionBindInput::SeqTableScanBindInput(Box::new( - SeqTableScanBindInput::new(table.clone()), - ))) - .unwrap() - .unwrap(); + bind_data = bind_func( + self.clone_client_context(), + TableFunctionBindInput::new(Some(table.clone()), None), + &mut vec![], + &mut vec![], + )?; } + let table_index = self.generate_table_index(); let logical_get = LogicalGet::new( LogicalOperatorBase::default(), table_index, diff --git a/src/planner_v2/binder/tableref/bind_table_function.rs b/src/planner_v2/binder/tableref/bind_table_function.rs new file mode 100644 index 0000000..b3917db --- /dev/null +++ b/src/planner_v2/binder/tableref/bind_table_function.rs @@ -0,0 +1,79 @@ +use derive_new::new; + +use super::BoundTableRef; +use crate::catalog_v2::{Catalog, CatalogEntry}; +use crate::function::TableFunctionBindInput; +use crate::planner_v2::{ + BindError, Binder, LogicalGet, LogicalOperator, LogicalOperatorBase, SqlparserResolver, +}; + +/// Represents a reference to a table-producing function call +#[derive(new, Debug)] +pub struct BoundTableFunction { + pub(crate) get: LogicalOperator, +} + +impl Binder { + pub fn bind_table_function( + &mut self, + table: sqlparser::ast::TableFactor, + ) -> Result { + match table { + sqlparser::ast::TableFactor::Table { + name, alias, args, .. + } => { + let (schema, table_function_name) = + SqlparserResolver::object_name_to_schema_table(&name)?; + let alias = alias + .map(|a| a.to_string()) + .unwrap_or_else(|| table_function_name.clone()); + + let function = Catalog::get_table_function( + self.clone_client_context(), + schema, + table_function_name, + )?; + + let table_func = function.functions[0].clone(); + let mut return_types = vec![]; + let mut return_names = vec![]; + let bind_data = if let Some(bind_func) = table_func.bind { + bind_func( + self.clone_client_context(), + TableFunctionBindInput::new(None, args), + &mut return_types, + &mut return_names, + )? + } else { + None + }; + + let table_index = self.generate_table_index(); + let logical_get = LogicalGet::new( + LogicalOperatorBase::default(), + table_index, + table_func, + bind_data, + return_types.clone(), + return_names.clone(), + ); + let plan = LogicalOperator::LogicalGet(logical_get); + // now add the table function to the bind context so its columns can be bound + self.bind_context.add_table_function( + alias, + table_index, + return_types, + return_names, + CatalogEntry::TableFunctionCatalogEntry(function), + ); + let bound_ref = + BoundTableRef::BoundTableFunction(Box::new(BoundTableFunction::new(plan))); + Ok(bound_ref) + } + other => Err(BindError::Internal(format!( + "unexpected table type: {}, only bind TableFactor::Table", + other + ))), + } + } +} diff --git a/src/planner_v2/binder/tableref/mod.rs b/src/planner_v2/binder/tableref/mod.rs index 4cd88f4..9aad3e5 100644 --- a/src/planner_v2/binder/tableref/mod.rs +++ b/src/planner_v2/binder/tableref/mod.rs @@ -1,15 +1,20 @@ mod bind_base_table_ref; mod bind_dummy_table_ref; mod bind_expression_list_ref; +mod bind_table_function; mod plan_base_table_ref; mod plan_dummy_table_ref; mod plan_expression_list_ref; +mod plan_table_function; + pub use bind_base_table_ref::*; pub use bind_dummy_table_ref::*; pub use bind_expression_list_ref::*; +pub use bind_table_function::*; pub use plan_base_table_ref::*; pub use plan_dummy_table_ref::*; pub use plan_expression_list_ref::*; +pub use plan_table_function::*; use super::{BindError, Binder}; @@ -18,6 +23,7 @@ pub enum BoundTableRef { BoundExpressionListRef(BoundExpressionListRef), BoundBaseTableRef(Box), BoundDummyTableRef(BoundDummyTableRef), + BoundTableFunction(Box), } impl Binder { diff --git a/src/planner_v2/binder/tableref/plan_table_function.rs b/src/planner_v2/binder/tableref/plan_table_function.rs new file mode 100644 index 0000000..e767abc --- /dev/null +++ b/src/planner_v2/binder/tableref/plan_table_function.rs @@ -0,0 +1,11 @@ +use super::BoundTableFunction; +use crate::planner_v2::{BindError, Binder, LogicalOperator}; + +impl Binder { + pub fn create_plan_for_table_function( + &mut self, + bound_ref: BoundTableFunction, + ) -> Result { + Ok(bound_ref.get) + } +} diff --git a/src/planner_v2/operator/logical_get.rs b/src/planner_v2/operator/logical_get.rs index 92683e1..c5a2c71 100644 --- a/src/planner_v2/operator/logical_get.rs +++ b/src/planner_v2/operator/logical_get.rs @@ -12,7 +12,7 @@ pub struct LogicalGet { /// The function that is called pub(crate) function: TableFunction, // The bind data of the function - pub(crate) bind_data: FunctionData, + pub(crate) bind_data: Option, /// The types of ALL columns that can be returned by the table function pub(crate) returned_types: Vec, /// The names of ALL columns that can be returned by the table function diff --git a/src/types_v2/values.rs b/src/types_v2/values.rs index d98e6fd..07eba7e 100644 --- a/src/types_v2/values.rs +++ b/src/types_v2/values.rs @@ -5,8 +5,11 @@ use std::iter::repeat; use std::sync::Arc; use arrow::array::{ - new_null_array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + new_null_array, ArrayBuilder, ArrayRef, BooleanArray, BooleanBuilder, Float32Array, + Float32Builder, Float64Array, Float64Builder, Int16Array, Int16Builder, Int32Array, + Int32Builder, Int64Array, Int64Builder, Int8Array, Int8Builder, StringArray, StringBuilder, + UInt16Array, UInt16Builder, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder, UInt8Array, + UInt8Builder, }; use arrow::datatypes::DataType; use ordered_float::OrderedFloat; @@ -281,6 +284,119 @@ impl ScalarValue { ScalarValue::Null => new_null_array(&DataType::Null, size), } } + + pub fn new_builder(data_type: &LogicalType) -> Result, TypeError> { + match data_type { + LogicalType::Invalid | LogicalType::SqlNull => Err(TypeError::InternalError(format!( + "Unsupported type {:?} for builder", + data_type + ))), + LogicalType::Boolean => Ok(Box::new(BooleanBuilder::new())), + LogicalType::Tinyint => Ok(Box::new(Int8Builder::new())), + LogicalType::UTinyint => Ok(Box::new(UInt8Builder::new())), + LogicalType::Smallint => Ok(Box::new(Int16Builder::new())), + LogicalType::USmallint => Ok(Box::new(UInt16Builder::new())), + LogicalType::Integer => Ok(Box::new(Int32Builder::new())), + LogicalType::UInteger => Ok(Box::new(UInt32Builder::new())), + LogicalType::Bigint => Ok(Box::new(Int64Builder::new())), + LogicalType::UBigint => Ok(Box::new(UInt64Builder::new())), + LogicalType::Float => Ok(Box::new(Float32Builder::new())), + LogicalType::Double => Ok(Box::new(Float64Builder::new())), + LogicalType::Varchar => Ok(Box::new(StringBuilder::new())), + } + } + + pub fn append_for_builder( + value: &ScalarValue, + builder: &mut Box, + ) -> Result<(), TypeError> { + match value { + ScalarValue::Null => { + return Err(TypeError::InternalError( + "Unsupported type: Null for builder".to_string(), + )) + } + ScalarValue::Boolean(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + ScalarValue::Utf8(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(v.as_ref()), + ScalarValue::Int8(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + ScalarValue::Int16(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + ScalarValue::Int32(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + ScalarValue::Int64(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + ScalarValue::UInt8(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + ScalarValue::UInt16(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + ScalarValue::UInt32(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + ScalarValue::UInt64(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + ScalarValue::Float32(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + ScalarValue::Float64(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + } + Ok(()) + } + + pub fn get_datatype(&self) -> DataType { + match self { + ScalarValue::Boolean(_) => DataType::Boolean, + ScalarValue::UInt8(_) => DataType::UInt8, + ScalarValue::UInt16(_) => DataType::UInt16, + ScalarValue::UInt32(_) => DataType::UInt32, + ScalarValue::UInt64(_) => DataType::UInt64, + ScalarValue::Int8(_) => DataType::Int8, + ScalarValue::Int16(_) => DataType::Int16, + ScalarValue::Int32(_) => DataType::Int32, + ScalarValue::Int64(_) => DataType::Int64, + ScalarValue::Float32(_) => DataType::Float32, + ScalarValue::Float64(_) => DataType::Float64, + ScalarValue::Utf8(_) => DataType::Utf8, + ScalarValue::Null => DataType::Null, + } + } } macro_rules! impl_scalar { diff --git a/src/util/tree_render.rs b/src/util/tree_render.rs index 496c8b9..d655fb1 100644 --- a/src/util/tree_render.rs +++ b/src/util/tree_render.rs @@ -64,14 +64,18 @@ impl TreeRender { } LogicalOperator::LogicalGet(op) => { let get_table_str = match &op.bind_data { - FunctionData::SeqTableScanInputData(input) => { - format!( - "{}.{}", - input.bind_table.storage.info.schema, - input.bind_table.storage.info.table - ) - } - FunctionData::None => "None".to_string(), + Some(data) => match data { + FunctionData::SeqTableScanInputData(input) => { + format!( + "{}.{}", + input.bind_table.storage.info.schema, + input.bind_table.storage.info.table + ) + } + FunctionData::Placeholder => todo!(), + FunctionData::SqlrsTablesData(_) => "sqlrs_tables".to_string(), + }, + None => "None".to_string(), }; format!("LogicalGet: {}", get_table_str) } diff --git a/tests/slt/pragma.slt b/tests/slt/pragma.slt new file mode 100644 index 0000000..f046148 --- /dev/null +++ b/tests/slt/pragma.slt @@ -0,0 +1,9 @@ +onlyif sqlrs_v2 +statement ok +create table t1(v1 int, v2 int, v3 int); + +onlyif sqlrs_v2 +query II +show tables +---- +main t1 diff --git a/tests/slt/table_function.slt b/tests/slt/table_function.slt new file mode 100644 index 0000000..25fdd19 --- /dev/null +++ b/tests/slt/table_function.slt @@ -0,0 +1,9 @@ +onlyif sqlrs_v2 +statement ok +create table t1(v1 int, v2 int, v3 int); + +onlyif sqlrs_v2 +query III +select schema_name, schema_oid, table_name from sqlrs_tables(); +---- +main 1 t1