Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions src/catalog_v2/catalog.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::sync::Arc;

use super::entry::{CatalogEntry, DataTable};
use super::{CatalogError, CatalogSet, TableCatalogEntry, TableFunctionCatalogEntry};
use crate::common::CreateTableFunctionInfo;
use super::{
CatalogError, CatalogSet, ScalarFunctionCatalogEntry, TableCatalogEntry,
TableFunctionCatalogEntry,
};
use crate::common::{CreateScalarFunctionInfo, CreateTableFunctionInfo};
use crate::main_entry::ClientContext;

/// The Catalog object represents the catalog of the database.
Expand Down Expand Up @@ -113,4 +116,39 @@ impl Catalog {
}
Err(CatalogError::CatalogEntryTypeNotMatch)
}

pub fn create_scalar_function(
client_context: Arc<ClientContext>,
info: CreateScalarFunctionInfo,
) -> Result<(), CatalogError> {
let mut catalog = match client_context.db.catalog.try_write() {
Ok(c) => c,
Err(_) => return Err(CatalogError::CatalogLockedError),
};
let version = catalog.catalog_version;
let entry = catalog.schemas.get_mut_entry(info.base.schema.clone())?;

if let CatalogEntry::SchemaCatalogEntry(mut_entry) = entry {
mut_entry.create_scalar_function(version + 1, info)?;
catalog.catalog_version += 1;
Ok(())
} else {
Err(CatalogError::CatalogEntryTypeNotMatch)
}
}

pub fn get_scalar_function(
client_context: Arc<ClientContext>,
schema: String,
scalar_function: String,
) -> Result<ScalarFunctionCatalogEntry, CatalogError> {
let catalog = match client_context.db.catalog.try_read() {
Ok(c) => c,
Err(_) => return Err(CatalogError::CatalogLockedError),
};
if let CatalogEntry::SchemaCatalogEntry(entry) = catalog.schemas.get_entry(schema)? {
return entry.get_scalar_function(scalar_function);
}
Err(CatalogError::CatalogEntryTypeNotMatch)
}
}
9 changes: 9 additions & 0 deletions src/catalog_v2/catalog_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ impl CatalogSet {
Err(CatalogError::CatalogEntryNotExists(name))
}

pub fn get_mut_entry(&mut self, name: String) -> Result<&mut CatalogEntry, CatalogError> {
if let Some(index) = self.mapping.get(&name) {
if let Some(entry) = self.entries.get_mut(index) {
return Ok(entry);
}
}
Err(CatalogError::CatalogEntryNotExists(name))
}

pub fn replace_entry(
&mut self,
name: String,
Expand Down
3 changes: 3 additions & 0 deletions src/catalog_v2/entry/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod scalar_function_catalog_entry;
mod schema_catalog_entry;
mod table_catalog_entry;
mod table_function_catalog_entry;

use derive_new::new;
pub use scalar_function_catalog_entry::*;
pub use schema_catalog_entry::*;
pub use table_catalog_entry::*;
pub use table_function_catalog_entry::*;
Expand All @@ -12,6 +14,7 @@ pub enum CatalogEntry {
SchemaCatalogEntry(SchemaCatalogEntry),
TableCatalogEntry(TableCatalogEntry),
TableFunctionCatalogEntry(TableFunctionCatalogEntry),
ScalarFunctionCatalogEntry(ScalarFunctionCatalogEntry),
}

impl CatalogEntry {
Expand Down
12 changes: 12 additions & 0 deletions src/catalog_v2/entry/scalar_function_catalog_entry.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use derive_new::new;

use super::CatalogEntryBase;
use crate::function::ScalarFunction;

#[derive(new, Clone, Debug)]
pub struct ScalarFunctionCatalogEntry {
#[allow(dead_code)]
pub(crate) base: CatalogEntryBase,
#[allow(dead_code)]
pub(crate) functions: Vec<ScalarFunction>,
}
30 changes: 28 additions & 2 deletions src/catalog_v2/entry/schema_catalog_entry.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use super::table_catalog_entry::{DataTable, TableCatalogEntry};
use super::{CatalogEntry, CatalogEntryBase, TableFunctionCatalogEntry};
use super::{
CatalogEntry, CatalogEntryBase, ScalarFunctionCatalogEntry, TableFunctionCatalogEntry,
};
use crate::catalog_v2::{CatalogError, CatalogSet};
use crate::common::CreateTableFunctionInfo;
use crate::common::{CreateScalarFunctionInfo, CreateTableFunctionInfo};

#[allow(dead_code)]
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -76,4 +78,28 @@ impl SchemaCatalogEntry {
result.extend(self.functions.scan_entries(callback));
result
}

pub fn create_scalar_function(
&mut self,
oid: usize,
info: CreateScalarFunctionInfo,
) -> Result<(), CatalogError> {
let entry = ScalarFunctionCatalogEntry::new(
CatalogEntryBase::new(oid, info.name.clone()),
info.functions,
);
let entry = CatalogEntry::ScalarFunctionCatalogEntry(entry);
self.functions.create_entry(info.name, entry)?;
Ok(())
}

pub fn get_scalar_function(
&self,
scalar_function: String,
) -> Result<ScalarFunctionCatalogEntry, CatalogError> {
match self.functions.get_entry(scalar_function.clone())? {
CatalogEntry::ScalarFunctionCatalogEntry(e) => Ok(e),
_ => Err(CatalogError::CatalogEntryNotExists(scalar_function)),
}
}
}
11 changes: 10 additions & 1 deletion src/common/create_info.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use derive_new::new;

use crate::catalog_v2::ColumnDefinition;
use crate::function::TableFunction;
use crate::function::{ScalarFunction, TableFunction};

#[derive(new, Debug, Clone)]
pub struct CreateInfoBase {
Expand All @@ -25,3 +25,12 @@ pub struct CreateTableFunctionInfo {
/// Functions with different arguments
pub(crate) functions: Vec<TableFunction>,
}

#[derive(new)]
pub struct CreateScalarFunctionInfo {
pub(crate) base: CreateInfoBase,
/// Function name
pub(crate) name: String,
/// Functions with different arguments
pub(crate) functions: Vec<ScalarFunction>,
}
9 changes: 9 additions & 0 deletions src/execution/expression_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ impl ExpressionExecutor {
let options = CastOptions { safe: e.try_cast };
cast_with_options(&child_result, &to_type, &options)?
}
BoundExpression::BoundFunctionExpression(e) => {
let children_result = e
.children
.iter()
.map(|c| Self::execute_internal(c, input))
.collect::<Result<Vec<_>, _>>()?;
let func = e.function.function;
func(&children_result)?
}
})
}
}
21 changes: 20 additions & 1 deletion src/function/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
mod errors;
mod scalar;
mod table;

use std::sync::Arc;

use derive_new::new;
pub use errors::*;
pub use scalar::*;
pub use table::*;

use crate::catalog_v2::{Catalog, DEFAULT_SCHEMA};
use crate::common::{CreateInfoBase, CreateTableFunctionInfo};
use crate::common::{CreateInfoBase, CreateScalarFunctionInfo, CreateTableFunctionInfo};
use crate::main_entry::ClientContext;

#[derive(Debug, Clone)]
Expand All @@ -33,9 +35,26 @@ impl BuiltinFunctions {
Ok(Catalog::create_table_function(self.context.clone(), info)?)
}

pub fn add_scalar_functions(
&mut self,
function_name: String,
functions: Vec<ScalarFunction>,
) -> Result<(), FunctionError> {
let info = CreateScalarFunctionInfo::new(
CreateInfoBase::new(DEFAULT_SCHEMA.to_string()),
function_name,
functions,
);
Ok(Catalog::create_scalar_function(self.context.clone(), info)?)
}

pub fn initialize(&mut self) -> Result<(), FunctionError> {
SqlrsTablesFunc::register_function(self)?;
SqlrsColumnsFunc::register_function(self)?;
AddFunction::register_function(self)?;
SubtractFunction::register_function(self)?;
MultiplyFunction::register_function(self)?;
DivideFunction::register_function(self)?;
Ok(())
}
}
157 changes: 157 additions & 0 deletions src/function/scalar/arithmetic_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
use std::sync::Arc;

use arrow::array::{ArrayRef, *};
use arrow::compute::{add_checked, divide_checked, multiply_checked, subtract_checked};
use arrow::datatypes::DataType;

use super::ScalarFunction;
use crate::function::{BuiltinFunctions, FunctionError};
use crate::types_v2::LogicalType;

/// Invoke a compute kernel on array(s)
macro_rules! compute_op {
// invoke binary operator
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
let rr = $RIGHT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
Ok(Arc::new($OP(&ll, &rr)?))
}};
// invoke unary operator
($OPERAND:expr, $OP:ident, $DT:ident) => {{
let operand = $OPERAND
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
Ok(Arc::new($OP(&operand)?))
}};
}

/// Invoke a compute kernel on a pair of arrays
/// The binary_primitive_array_op macro only evaluates for primitive types
/// like integers and floats.
macro_rules! binary_primitive_array_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
match $LEFT.data_type() {
DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
other => Err(FunctionError::InternalError(format!(
"Data type {:?} not supported for binary operation '{}' on primitive arrays",
other,
stringify!($OP)
))),
}
}};
}
pub struct AddFunction;

impl AddFunction {
fn add(inputs: &[ArrayRef]) -> Result<ArrayRef, FunctionError> {
assert!(inputs.len() == 2);
let left = &inputs[0];
let right = &inputs[1];
binary_primitive_array_op!(left, right, add_checked)
}

pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> {
let mut functions = vec![];
for ty in LogicalType::numeric().iter() {
functions.push(ScalarFunction::new(
"add".to_string(),
Self::add,
vec![ty.clone(), ty.clone()],
ty.clone(),
));
}
set.add_scalar_functions("add".to_string(), functions.clone())?;
Ok(())
}
}

pub struct SubtractFunction;

impl SubtractFunction {
fn subtract(inputs: &[ArrayRef]) -> Result<ArrayRef, FunctionError> {
assert!(inputs.len() == 2);
let left = &inputs[0];
let right = &inputs[1];
binary_primitive_array_op!(left, right, subtract_checked)
}

pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> {
let mut functions = vec![];
for ty in LogicalType::numeric().iter() {
functions.push(ScalarFunction::new(
"subtract".to_string(),
Self::subtract,
vec![ty.clone(), ty.clone()],
ty.clone(),
));
}
set.add_scalar_functions("subtract".to_string(), functions.clone())?;
Ok(())
}
}

pub struct MultiplyFunction;

impl MultiplyFunction {
fn multiply(inputs: &[ArrayRef]) -> Result<ArrayRef, FunctionError> {
assert!(inputs.len() == 2);
let left = &inputs[0];
let right = &inputs[1];
binary_primitive_array_op!(left, right, multiply_checked)
}

pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> {
let mut functions = vec![];
for ty in LogicalType::numeric().iter() {
functions.push(ScalarFunction::new(
"multiply".to_string(),
Self::multiply,
vec![ty.clone(), ty.clone()],
ty.clone(),
));
}
set.add_scalar_functions("multiply".to_string(), functions.clone())?;
Ok(())
}
}

pub struct DivideFunction;

impl DivideFunction {
fn divide(inputs: &[ArrayRef]) -> Result<ArrayRef, FunctionError> {
assert!(inputs.len() == 2);
let left = &inputs[0];
let right = &inputs[1];
binary_primitive_array_op!(left, right, divide_checked)
}

pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> {
let mut functions = vec![];
for ty in LogicalType::numeric().iter() {
functions.push(ScalarFunction::new(
"divide".to_string(),
Self::divide,
vec![ty.clone(), ty.clone()],
ty.clone(),
));
}
set.add_scalar_functions("divide".to_string(), functions.clone())?;
Ok(())
}
}
4 changes: 4 additions & 0 deletions src/function/scalar/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mod arithmetic_function;
mod scalar_function;
pub use arithmetic_function::*;
pub use scalar_function::*;
Loading