diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index e19e274341a55..f3fa5b2c5de5c 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -293,6 +293,155 @@ impl ScalarValue {
         self.to_array_of_size(1)
     }
 
+    /// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`]
+    /// corresponding to those values.
+    ///
+    /// Returns an error if the iterator is empty or if the
+    /// [`ScalarValue`]s are not all the same type
+    ///
+    /// # Example
+    /// ```
+    /// use datafusion::scalar::ScalarValue;
+    /// use arrow::array::{ArrayRef, BooleanArray};
+    ///
+    /// let scalars = vec![
+    ///     ScalarValue::Boolean(Some(true)),
+    ///     ScalarValue::Boolean(None),
+    ///     ScalarValue::Boolean(Some(false)),
+    /// ];
+    ///
+    /// // Build an Array from the list of ScalarValues
+    /// let array = ScalarValue::iter_to_array(scalars.iter())
+    ///     .unwrap();
+    ///
+    /// let expected: ArrayRef = std::sync::Arc::new(
+    ///     BooleanArray::from(vec![
+    ///         Some(true),
+    ///         None,
+    ///         Some(false)
+    ///     ]
+    /// ));
+    ///
+    /// assert_eq!(&array, &expected);
+    /// ```
+    pub fn iter_to_array<'a>(
+        scalars: impl IntoIterator<Item = &'a ScalarValue>,
+    ) -> Result<ArrayRef> {
+        let mut scalars = scalars.into_iter().peekable();
+
+        // figure out the type based on the first element
+        let data_type = match scalars.peek() {
+            None => {
+                return Err(DataFusionError::Internal(
+                    "Empty iterator passed to ScalarValue::iter_to_array".to_string(),
+                ))
+            }
+            Some(sv) => sv.get_datatype(),
+        };
+
+        /// Creates an array of $ARRAY_TY by unpacking values of
+        /// SCALAR_TY for primitive types
+        macro_rules! build_array_primitive {
+            ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{
+                {
+                    let values = scalars
+                        .map(|sv| {
+                            if let ScalarValue::$SCALAR_TY(v) = sv {
+                                Ok(*v)
+                            } else {
+                                Err(DataFusionError::Internal(format!(
+                                    "Inconsistent types in ScalarValue::iter_to_array. \
+                                     Expected {:?}, got {:?}",
+                                    data_type, sv
+                                )))
+                            }
+                        })
+                        .collect::<Result<Vec<_>>>()?;
+
+                    let array: $ARRAY_TY = values.iter().collect();
+                    Arc::new(array)
+                }
+            }};
+        }
+
+        /// Creates an array of $ARRAY_TY by unpacking values of
+        /// SCALAR_TY for "string-like" types.
+        macro_rules! build_array_string {
+            ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{
+                {
+                    let values = scalars
+                        .map(|sv| {
+                            if let ScalarValue::$SCALAR_TY(v) = sv {
+                                Ok(v)
+                            } else {
+                                Err(DataFusionError::Internal(format!(
+                                    "Inconsistent types in ScalarValue::iter_to_array. \
+                                     Expected {:?}, got {:?}",
+                                    data_type, sv
+                                )))
+                            }
+                        })
+                        .collect::<Result<Vec<_>>>()?;
+
+                    // it is annoying that one can not create
+                    // StringArray et al directly from iter of &String,
+                    // requiring this map to &str
+                    let values = values.iter().map(|s| s.as_ref());
+
+                    let array: $ARRAY_TY = values.collect();
+                    Arc::new(array)
+                }
+            }};
+        }
+
+        let array: ArrayRef = match &data_type {
+            DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
+            DataType::Float32 => build_array_primitive!(Float32Array, Float32),
+            DataType::Float64 => build_array_primitive!(Float64Array, Float64),
+            DataType::Int8 => build_array_primitive!(Int8Array, Int8),
+            DataType::Int16 => build_array_primitive!(Int16Array, Int16),
+            DataType::Int32 => build_array_primitive!(Int32Array, Int32),
+            DataType::Int64 => build_array_primitive!(Int64Array, Int64),
+            DataType::UInt8 => build_array_primitive!(UInt8Array, UInt8),
+            DataType::UInt16 => build_array_primitive!(UInt16Array, UInt16),
+            DataType::UInt32 => build_array_primitive!(UInt32Array, UInt32),
+            DataType::UInt64 => build_array_primitive!(UInt64Array, UInt64),
+            DataType::Utf8 => build_array_string!(StringArray, Utf8),
+            DataType::LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8),
+            DataType::Binary => build_array_string!(BinaryArray, Binary),
+            DataType::LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary),
+            DataType::Date32 => build_array_primitive!(Date32Array, Date32),
+            DataType::Date64 => build_array_primitive!(Date64Array, Date64),
+            DataType::Timestamp(TimeUnit::Second, None) => {
+                build_array_primitive!(TimestampSecondArray, TimestampSecond)
+            }
+            DataType::Timestamp(TimeUnit::Millisecond, None) => {
+                build_array_primitive!(TimestampMillisecondArray, TimestampMillisecond)
+            }
+            DataType::Timestamp(TimeUnit::Microsecond, None) => {
+                build_array_primitive!(TimestampMicrosecondArray, TimestampMicrosecond)
+            }
+            DataType::Timestamp(TimeUnit::Nanosecond, None) => {
+                build_array_primitive!(TimestampNanosecondArray, TimestampNanosecond)
+            }
+            DataType::Interval(IntervalUnit::DayTime) => {
+                build_array_primitive!(IntervalDayTimeArray, IntervalDayTime)
+            }
+            DataType::Interval(IntervalUnit::YearMonth) => {
+                build_array_primitive!(IntervalYearMonthArray, IntervalYearMonth)
+            }
+            _ => {
+                return Err(DataFusionError::Internal(format!(
+                    "Unsupported creation of {:?} array from ScalarValue {:?}",
+                    data_type,
+                    scalars.peek()
+                )))
+            }
+        };
+
+        Ok(array)
+    }
+
     /// Converts a scalar value into an array of `size` rows.
     pub fn to_array_of_size(&self, size: usize) -> ArrayRef {
         match self {
@@ -609,6 +758,12 @@ impl From for ScalarValue {
     }
 }
 
+impl From<&str> for ScalarValue {
+    fn from(value: &str) -> Self {
+        ScalarValue::Utf8(Some(value.to_string()))
+    }
+}
+
 macro_rules! impl_try_from {
     ($SCALAR:ident, $NATIVE:ident) => {
         impl TryFrom for $NATIVE {
@@ -940,4 +1095,139 @@ mod tests {
         assert!(prim_array.is_null(1));
         assert_eq!(prim_array.value(2), 101);
     }
+
+    /// Creates array directly and via ScalarValue and ensures they are the same
+    macro_rules! check_scalar_iter {
+        ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{
+            let scalars: Vec<_> =
+                $INPUT.iter().map(|v| ScalarValue::$SCALAR_T(*v)).collect();
+
+            let array = ScalarValue::iter_to_array(scalars.iter()).unwrap();
+
+            let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT));
+
+            assert_eq!(&array, &expected);
+        }};
+    }
+
+    /// Creates array directly and via ScalarValue and ensures they
+    /// are the same, for string arrays
+    macro_rules! check_scalar_iter_string {
+        ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{
+            let scalars: Vec<_> = $INPUT
+                .iter()
+                .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_string())))
+                .collect();
+
+            let array = ScalarValue::iter_to_array(scalars.iter()).unwrap();
+
+            let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT));
+
+            assert_eq!(&array, &expected);
+        }};
+    }
+
+    /// Creates array directly and via ScalarValue and ensures they
+    /// are the same, for binary arrays
+    macro_rules! check_scalar_iter_binary {
+        ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{
+            let scalars: Vec<_> = $INPUT
+                .iter()
+                .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_vec())))
+                .collect();
+
+            let array = ScalarValue::iter_to_array(scalars.iter()).unwrap();
+
+            let expected: $ARRAYTYPE =
+                $INPUT.iter().map(|v| v.map(|v| v.to_vec())).collect();
+
+            let expected: ArrayRef = Arc::new(expected);
+
+            assert_eq!(&array, &expected);
+        }};
+    }
+
+    #[test]
+    fn scalar_iter_to_array_boolean() {
+        check_scalar_iter!(Boolean, BooleanArray, vec![Some(true), None, Some(false)]);
+        check_scalar_iter!(Float32, Float32Array, vec![Some(1.9), None, Some(-2.1)]);
+        check_scalar_iter!(Float64, Float64Array, vec![Some(1.9), None, Some(-2.1)]);
+
+        check_scalar_iter!(Int8, Int8Array, vec![Some(1), None, Some(3)]);
+        check_scalar_iter!(Int16, Int16Array, vec![Some(1), None, Some(3)]);
+        check_scalar_iter!(Int32, Int32Array, vec![Some(1), None, Some(3)]);
+        check_scalar_iter!(Int64, Int64Array, vec![Some(1), None, Some(3)]);
+
+        check_scalar_iter!(UInt8, UInt8Array, vec![Some(1), None, Some(3)]);
+        check_scalar_iter!(UInt16, UInt16Array, vec![Some(1), None, Some(3)]);
+        check_scalar_iter!(UInt32, UInt32Array, vec![Some(1), None, Some(3)]);
+        check_scalar_iter!(UInt64, UInt64Array, vec![Some(1), None, Some(3)]);
+
+        check_scalar_iter!(
+            TimestampSecond,
+            TimestampSecondArray,
+            vec![Some(1), None, Some(3)]
+        );
+        check_scalar_iter!(
+            TimestampMillisecond,
+            TimestampMillisecondArray,
+            vec![Some(1), None, Some(3)]
+        );
+        check_scalar_iter!(
+            TimestampMicrosecond,
+            TimestampMicrosecondArray,
+            vec![Some(1), None, Some(3)]
+        );
+        check_scalar_iter!(
+            TimestampNanosecond,
+            TimestampNanosecondArray,
+            vec![Some(1), None, Some(3)]
+        );
+
+        check_scalar_iter_string!(
+            Utf8,
+            StringArray,
+            vec![Some("foo"), None, Some("bar")]
+        );
+        check_scalar_iter_string!(
+            LargeUtf8,
+            LargeStringArray,
+            vec![Some("foo"), None, Some("bar")]
+        );
+        check_scalar_iter_binary!(
+            Binary,
+            BinaryArray,
+            vec![Some(b"foo"), None, Some(b"bar")]
+        );
+        check_scalar_iter_binary!(
+            LargeBinary,
+            LargeBinaryArray,
+            vec![Some(b"foo"), None, Some(b"bar")]
+        );
+    }
+
+    #[test]
+    fn scalar_iter_to_array_empty() {
+        let scalars = vec![] as Vec<ScalarValue>;
+
+        let result = ScalarValue::iter_to_array(scalars.iter()).unwrap_err();
+        assert!(
+            result
+                .to_string()
+                .contains("Empty iterator passed to ScalarValue::iter_to_array"),
+            "{}",
+            result
+        );
+    }
+
+    #[test]
+    fn scalar_iter_to_array_mismatched_types() {
+        use ScalarValue::*;
+        // If the scalar values are not all the correct type, error here
+        let scalars: Vec<ScalarValue> = vec![Boolean(Some(true)), Int32(Some(5))];
+
+        let result = ScalarValue::iter_to_array(scalars.iter()).unwrap_err();
+        assert!(result.to_string().contains("Inconsistent types in ScalarValue::iter_to_array. Expected Boolean, got Int32(5)"),
+                "{}", result);
+    }
 }