Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ pub enum BuiltinScalarFunction {
ArrayExcept,
/// cardinality
Cardinality,
/// array_resize
ArrayResize,
/// construct an array from columns
MakeArray,
/// Flatten
Expand Down Expand Up @@ -430,6 +432,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
BuiltinScalarFunction::ArrayUnion => Volatility::Immutable,
BuiltinScalarFunction::ArrayResize => Volatility::Immutable,
BuiltinScalarFunction::Range => Volatility::Immutable,
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
Expand Down Expand Up @@ -617,6 +620,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayResize => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
BuiltinScalarFunction::ArrayIntersect => {
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
Expand Down Expand Up @@ -980,6 +984,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayResize => {
Signature::variadic_any(self.volatility())
}

BuiltinScalarFunction::Range => Signature::one_of(
vec![
Exact(vec![Int64]),
Expand Down Expand Up @@ -1647,6 +1655,7 @@ impl BuiltinScalarFunction {
],
BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"],
BuiltinScalarFunction::Cardinality => &["cardinality"],
BuiltinScalarFunction::ArrayResize => &["array_resize", "list_resize"],
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
BuiltinScalarFunction::ArrayIntersect => {
&["array_intersect", "list_intersect"]
Expand Down
8 changes: 8 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,14 @@ scalar_expr!(
array,
"returns the total number of elements in the array."
);

// Declares the `array_resize` expression builder for
// `BuiltinScalarFunction::ArrayResize`, naming three arguments:
// `array`, `size` and `value`.
// NOTE(review): the `scalar_expr!` definition is not visible here; presumably
// it expands to a function taking one `Expr` per listed argument and uses the
// string as its documentation — confirm against the macro.
scalar_expr!(
ArrayResize,
array_resize,
array size value,
"returns an array with the specified size filled with the given value."
);

nary_scalar_expr!(
MakeArray,
array,
Expand Down
113 changes: 107 additions & 6 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use arrow::buffer::OffsetBuffer;
use arrow::compute;
use arrow::datatypes::{DataType, Field, UInt64Type};
use arrow::row::{RowConverter, SortField};
use arrow_buffer::NullBuffer;
use arrow_buffer::{ArrowNativeType, NullBuffer};

use arrow_schema::{FieldRef, SortOptions};
use datafusion_common::cast::{
Expand All @@ -36,7 +36,8 @@ use datafusion_common::cast::{
};
use datafusion_common::utils::{array_into_list_array, list_ndims};
use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result,
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
DataFusionError, Result, ScalarValue,
};

use itertools::Itertools;
Expand Down Expand Up @@ -1190,7 +1191,10 @@ pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

concat_internal::<i32>(new_args.as_slice())
match &args[0].data_type() {
DataType::LargeList(_) => concat_internal::<i64>(new_args.as_slice()),
_ => concat_internal::<i32>(new_args.as_slice()),
}
}

/// Array_empty SQL function
Expand Down Expand Up @@ -1239,7 +1243,7 @@ pub fn array_repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_large_list_array(element)?;
general_list_repeat::<i64>(list_array, count_array)
}
_ => general_repeat(element, count_array),
_ => general_repeat::<i32>(element, count_array),
}
}

Expand All @@ -1255,7 +1259,10 @@ pub fn array_repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]]
/// )
/// ```
fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result<ArrayRef> {
fn general_repeat<O: OffsetSizeTrait>(
array: &ArrayRef,
count_array: &Int64Array,
) -> Result<ArrayRef> {
let data_type = array.data_type();
let mut new_values = vec![];

Expand Down Expand Up @@ -1288,7 +1295,7 @@ fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result<ArrayRef
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
let values = compute::concat(&new_values)?;

Ok(Arc::new(ListArray::try_new(
Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new("item", data_type.to_owned(), true)),
OffsetBuffer::from_lengths(count_vec),
values,
Expand Down Expand Up @@ -2611,6 +2618,100 @@ pub fn array_distinct(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

/// `array_resize` SQL function.
///
/// Argument layout:
/// * `arg[0]` - the `List` / `LargeList` array to resize,
/// * `arg[1]` - an `Int64` array holding the requested size of each row,
/// * `arg[2]` - optional array providing the element used for padding.
///
/// # Errors
///
/// Fails when the number of arguments is not two or three, when `arg[1]`
/// cannot be read as an `Int64` array, or when `arg[0]` is neither a `List`
/// nor a `LargeList`.
pub fn array_resize(arg: &[ArrayRef]) -> Result<ArrayRef> {
    if !(2..=3).contains(&arg.len()) {
        return exec_err!("array_resize needs two or three arguments");
    }

    let sizes = as_int64_array(&arg[1])?;
    // The third argument is optional; `None` lets the helper choose the filler.
    let filler = arg.get(2).cloned();
    let input = &arg[0];

    // Dispatch on the offset width of the incoming list type.
    match input.data_type() {
        DataType::List(field) => {
            general_list_resize::<i32>(as_list_array(input)?, sizes, field, filler)
        }
        DataType::LargeList(field) => {
            general_list_resize::<i64>(as_large_list_array(input)?, sizes, field, filler)
        }
        array_type => exec_err!("array_resize does not support type '{array_type:?}'."),
    }
}

/// array_resize keep the original array and append the default element to the end
fn general_list_resize<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
count_array: &Int64Array,
field: &FieldRef,
default_element: Option<ArrayRef>,
) -> Result<ArrayRef>
where
O: TryInto<i64>,
{
let data_type = array.value_type();

let values = array.values();
let original_data = values.to_data();

// create default element array
let default_element = if let Some(default_element) = default_element {
default_element
} else {
let null_scalar = ScalarValue::try_from(&data_type)?;
null_scalar.to_array_of_size(original_data.len())?
};
let default_value_data = default_element.to_data();

// create a mutable array to store the original data
let capacity = Capacities::Array(original_data.len() + default_value_data.len());
let mut offsets = vec![O::usize_as(0)];
let mut mutable = MutableArrayData::with_capacities(
vec![&original_data, &default_value_data],
false,
capacity,
);

for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
let count = count_array.value(row_index).to_usize().ok_or_else(|| {
internal_datafusion_err!("array_resize: failed to convert size to usize")
})?;
let count = O::usize_as(count);
let start = offset_window[0];
if start + count > offset_window[1] {
let extra_count =
(start + count - offset_window[1]).try_into().map_err(|_| {
internal_datafusion_err!(
"array_resize: failed to convert size to i64"
)
})?;
let end = offset_window[1];
mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
// append default element
for _ in 0..extra_count {
mutable.extend(1, row_index, row_index + 1);
}
} else {
let end = start + count;
mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
};
offsets.push(offsets[row_index] + count);
}

let data = mutable.freeze();
Ok(Arc::new(GenericListArray::<O>::try_new(
field.clone(),
OffsetBuffer::<O>::new(offsets.into()),
arrow_array::make_array(data),
None,
)?))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Cardinality => {
Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args))
}
BuiltinScalarFunction::ArrayResize => {
Arc::new(|args| make_scalar_function(array_expressions::array_resize)(args))
}
BuiltinScalarFunction::MakeArray => {
Arc::new(|args| make_scalar_function(array_expressions::make_array)(args))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ enum ScalarFunction {
FindInSet = 127;
ArraySort = 128;
ArrayDistinct = 129;
ArrayResize = 130;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ArrayToString => Self::ArrayToString,
ScalarFunction::ArrayIntersect => Self::ArrayIntersect,
ScalarFunction::ArrayUnion => Self::ArrayUnion,
ScalarFunction::ArrayResize => Self::ArrayResize,
ScalarFunction::Range => Self::Range,
ScalarFunction::Cardinality => Self::Cardinality,
ScalarFunction::Array => Self::MakeArray,
Expand Down Expand Up @@ -1499,6 +1500,11 @@ pub fn parse_expr(
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, _>>()?,
)),
// `ArrayResize` must round-trip to `array_resize`; building `array_slice`
// here would silently turn a resize into a slice after deserialization.
// The fully qualified path is used so no extra `use` item is required.
ScalarFunction::ArrayResize => Ok(datafusion_expr::expr_fn::array_resize(
    parse_expr(&args[0], registry)?,
    parse_expr(&args[1], registry)?,
    parse_expr(&args[2], registry)?,
)),
ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)),
ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)),
ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)),
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ArrayPositions => Self::ArrayPositions,
BuiltinScalarFunction::ArrayPrepend => Self::ArrayPrepend,
BuiltinScalarFunction::ArrayRepeat => Self::ArrayRepeat,
BuiltinScalarFunction::ArrayResize => Self::ArrayResize,
BuiltinScalarFunction::ArrayRemove => Self::ArrayRemove,
BuiltinScalarFunction::ArrayRemoveN => Self::ArrayRemoveN,
BuiltinScalarFunction::ArrayRemoveAll => Self::ArrayRemoveAll,
Expand Down
Loading