Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions rust/datafusion/examples/simple_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use arrow::{
util::pretty,
};

use datafusion::error::Result;
use datafusion::{physical_plan::functions::ScalarFunctionImplementation, prelude::*};
use datafusion::prelude::*;
use datafusion::{error::Result, physical_plan::functions::make_scalar_function};
use std::sync::Arc;

// create local execution context with an in-memory table
Expand Down Expand Up @@ -60,7 +60,7 @@ async fn main() -> Result<()> {
let mut ctx = create_context()?;

// First, declare the actual implementation of the calculation
let pow: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| {
let pow = |args: &[ArrayRef]| {
// in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to:
// 1. cast the values to the type we want
// 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result
Expand Down Expand Up @@ -97,8 +97,11 @@ async fn main() -> Result<()> {

// `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!)
// `Arc` because arrays are immutable, thread-safe, trait objects.
Ok(Arc::new(array))
});
Ok(Arc::new(array) as ArrayRef)
};
// the function above expects an `ArrayRef`, but DataFusion may pass a scalar to a UDF.
// thus, we use `make_scalar_function` to decorare the closure so that it can handle both Arrays and Scalar values.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

let pow = make_scalar_function(pow);

// Next:
// * give it a name so that it shows nicely when the plan is printed
Expand Down
9 changes: 5 additions & 4 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ impl FunctionRegistry for ExecutionContextState {
mod tests {

use super::*;
use crate::physical_plan::functions::ScalarFunctionImplementation;
use crate::physical_plan::functions::make_scalar_function;
use crate::physical_plan::{collect, collect_partitioned};
use crate::test;
use crate::variable::VarType;
Expand Down Expand Up @@ -1618,7 +1618,7 @@ mod tests {
let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?;
ctx.register_table("t", Box::new(provider));

let myfunc: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| {
let myfunc = |args: &[ArrayRef]| {
let l = &args[0]
.as_any()
.downcast_ref::<Int32Array>()
Expand All @@ -1627,8 +1627,9 @@ mod tests {
.as_any()
.downcast_ref::<Int32Array>()
.expect("cast failed");
Ok(Arc::new(add(l, r)?))
});
Ok(Arc::new(add(l, r)?) as ArrayRef)
};
let myfunc = make_scalar_function(myfunc);

ctx.register_udf(create_udf(
"my_add",
Expand Down
6 changes: 3 additions & 3 deletions rust/datafusion/src/execution/dataframe_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,11 @@ impl DataFrame for DataFrameImpl {
#[cfg(test)]
mod tests {
use super::*;
use crate::datasource::csv::CsvReadOptions;
use crate::execution::context::ExecutionContext;
use crate::logical_plan::*;
use crate::{datasource::csv::CsvReadOptions, physical_plan::ColumnarValue};
use crate::{physical_plan::functions::ScalarFunctionImplementation, test};
use arrow::{array::ArrayRef, datatypes::DataType};
use arrow::datatypes::DataType;

#[test]
fn select_columns() -> Result<()> {
Expand Down Expand Up @@ -287,7 +287,7 @@ mod tests {

// declare the udf
let my_fn: ScalarFunctionImplementation =
Arc::new(|_: &[ArrayRef]| unimplemented!("my_fn is not implemented"));
Arc::new(|_: &[ColumnarValue]| unimplemented!("my_fn is not implemented"));

// create and register the udf
ctx.register_udf(create_udf(
Expand Down
23 changes: 21 additions & 2 deletions rust/datafusion/src/physical_plan/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use arrow::array::*;
use arrow::datatypes::DataType;
use std::sync::Arc;

use super::ColumnarValue;

macro_rules! downcast_vec {
($ARGS:expr, $ARRAY_TYPE:ident) => {{
$ARGS
Expand Down Expand Up @@ -58,8 +60,7 @@ macro_rules! array {
}};
}

/// put values in an array.
pub fn array(args: &[ArrayRef]) -> Result<ArrayRef> {
fn array_array(args: &[&dyn Array]) -> Result<ArrayRef> {
// do not accept 0 arguments.
if args.is_empty() {
return Err(DataFusionError::Internal(
Expand Down Expand Up @@ -88,6 +89,24 @@ pub fn array(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

/// put values in an array.
pub fn array(values: &[ColumnarValue]) -> Result<ColumnarValue> {
let arrays: Vec<&dyn Array> = values
.iter()
.map(|value| {
if let ColumnarValue::Array(value) = value {
Ok(value.as_ref())
} else {
Err(DataFusionError::NotImplemented(
"Array is not implemented for scalar values.".to_string(),
))
}
})
.collect::<Result<_>>()?;

Ok(ColumnarValue::Array(array_array(&arrays)?))
}

/// Currently supported types by the array function.
/// The order of these types correspond to the order on which coercion applies
/// This should thus be from least informative to most informative
Expand Down
191 changes: 142 additions & 49 deletions rust/datafusion/src/physical_plan/crypto_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,26 @@

//! Crypto expressions

use std::sync::Arc;

use md5::Md5;
use sha2::{
digest::Output as SHA2DigestOutput, Digest as SHA2Digest, Sha224, Sha256, Sha384,
Sha512,
};

use crate::error::{DataFusionError, Result};
use arrow::array::{
ArrayRef, GenericBinaryArray, GenericStringArray, StringOffsetSizeTrait,
use crate::{
error::{DataFusionError, Result},
scalar::ScalarValue,
};
use arrow::{
array::{Array, BinaryArray, GenericStringArray, StringOffsetSizeTrait},
datatypes::DataType,
};

use super::{string_expressions::unary_string_function, ColumnarValue};

/// Computes the md5 of a string.
fn md5_process(input: &str) -> String {
let mut digest = Md5::default();
digest.update(&input);
Expand All @@ -49,58 +58,142 @@ fn sha_process<D: SHA2Digest + Default>(input: &str) -> SHA2DigestOutput<D> {
digest.finalize()
}

macro_rules! crypto_unary_string_function {
($NAME:ident, $FUNC:expr) => {
/// crypto function that accepts Utf8 or LargeUtf8 and returns Utf8 string
pub fn $NAME<T: StringOffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<GenericStringArray<i32>> {
if args.len() != 1 {
return Err(DataFusionError::Internal(format!(
"{:?} args were supplied but {} takes exactly one argument",
args.len(),
String::from(stringify!($NAME)),
)));
}
/// # Errors
/// This function errors when:
/// * the number of arguments is not 1
/// * the first argument is not castable to a `GenericStringArray`
fn unary_binary_function<T, R, F>(
args: &[&dyn Array],
op: F,
name: &str,
) -> Result<BinaryArray>
where
R: AsRef<[u8]>,
T: StringOffsetSizeTrait,
F: Fn(&str) -> R,
{
if args.len() != 1 {
return Err(DataFusionError::Internal(format!(
"{:?} args were supplied but {} takes exactly one argument",
args.len(),
name,
)));
}

let array = args[0]
.as_any()
.downcast_ref::<GenericStringArray<T>>()
.ok_or_else(|| {
DataFusionError::Internal("failed to downcast to string".to_string())
})?;

let array = args[0]
.as_any()
.downcast_ref::<GenericStringArray<T>>()
.unwrap();
// first map is the iterator, second is for the `Option<_>`
Ok(array.iter().map(|x| x.map(|x| op(x))).collect())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Ok(array.iter().map(|x| x.map(|x| op(x))).collect())
Ok(array.iter().map(|x| x.map(op)).collect())

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, though, I also though it would work. However, because the functions have different signatures, a deref is needed and thus we need to write it explicitly. Same for md5_process.

}

fn handle<F, R>(args: &[ColumnarValue], op: F, name: &str) -> Result<ColumnarValue>
where
R: AsRef<[u8]>,
F: Fn(&str) -> R,
{
match &args[0] {
ColumnarValue::Array(a) => match a.data_type() {
DataType::Utf8 => {
Ok(ColumnarValue::Array(Arc::new(unary_binary_function::<
i32,
_,
_,
>(
&[a.as_ref()], op, name
)?)))
}
DataType::LargeUtf8 => {
Ok(ColumnarValue::Array(Arc::new(unary_binary_function::<
i64,
_,
_,
>(
&[a.as_ref()], op, name
)?)))
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function {}",
other, name,
))),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) => {
let result = a.as_ref().map(|x| (op)(x).as_ref().to_vec());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let result = a.as_ref().map(|x| (op)(x).as_ref().to_vec());
let result = a.as_ref().map(|x| op(x).as_ref().to_vec());

Ok(ColumnarValue::Scalar(ScalarValue::Binary(result)))
}
ScalarValue::LargeUtf8(a) => {
let result = a.as_ref().map(|x| (op)(x).as_ref().to_vec());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let result = a.as_ref().map(|x| (op)(x).as_ref().to_vec());
let result = a.as_ref().map(|x| op(x).as_ref().to_vec());

Ok(ColumnarValue::Scalar(ScalarValue::Binary(result)))
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function {}",
other, name,
))),
},
}
}

// first map is the iterator, second is for the `Option<_>`
Ok(array.iter().map(|x| x.map(|x| $FUNC(x))).collect())
}
};
fn md5_array<T: StringOffsetSizeTrait>(
args: &[&dyn Array],
) -> Result<GenericStringArray<i32>> {
unary_string_function::<T, i32, _, _>(args, md5_process, "md5")
}

macro_rules! crypto_unary_binary_function {
($NAME:ident, $FUNC:expr) => {
/// crypto function that accepts Utf8 or LargeUtf8 and returns Binary
pub fn $NAME<T: StringOffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<GenericBinaryArray<i32>> {
if args.len() != 1 {
return Err(DataFusionError::Internal(format!(
"{:?} args were supplied but {} takes exactly one argument",
args.len(),
String::from(stringify!($NAME)),
)));
/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`]
pub fn md5(args: &[ColumnarValue]) -> Result<ColumnarValue> {
match &args[0] {
ColumnarValue::Array(a) => match a.data_type() {
DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new(md5_array::<i32>(&[
a.as_ref()
])?))),
DataType::LargeUtf8 => {
Ok(ColumnarValue::Array(Arc::new(md5_array::<i64>(&[
a.as_ref()
])?)))
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function md5",
other,
))),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) => {
let result = a.as_ref().map(|x| md5_process(x));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let result = a.as_ref().map(|x| md5_process(x));
let result = a.as_ref().map(md5_process);

Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
}
ScalarValue::LargeUtf8(a) => {
let result = a.as_ref().map(|x| md5_process(x));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let result = a.as_ref().map(|x| md5_process(x));
let result = a.as_ref().map(md5_process);

Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function md5",
other,
))),
},
}
}

let array = args[0]
.as_any()
.downcast_ref::<GenericStringArray<T>>()
.unwrap();
/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`]
pub fn sha224(args: &[ColumnarValue]) -> Result<ColumnarValue> {
handle(args, sha_process::<Sha224>, "ssh224")
}

// first map is the iterator, second is for the `Option<_>`
Ok(array.iter().map(|x| x.map(|x| $FUNC(x))).collect())
}
};
/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`]
pub fn sha256(args: &[ColumnarValue]) -> Result<ColumnarValue> {
handle(args, sha_process::<Sha256>, "sha256")
}

crypto_unary_string_function!(md5, md5_process);
crypto_unary_binary_function!(sha224, sha_process::<Sha224>);
crypto_unary_binary_function!(sha256, sha_process::<Sha256>);
crypto_unary_binary_function!(sha384, sha_process::<Sha384>);
crypto_unary_binary_function!(sha512, sha_process::<Sha512>);
/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`]
pub fn sha384(args: &[ColumnarValue]) -> Result<ColumnarValue> {
handle(args, sha_process::<Sha384>, "sha384")
}

/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`]
pub fn sha512(args: &[ColumnarValue]) -> Result<ColumnarValue> {
handle(args, sha_process::<Sha512>, "sha512")
}
Loading