diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index bca3dbde3276b..d22e2fb013707 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -23,10 +23,12 @@ use crate::{downcast_value, DataFusionError}; use arrow::{ array::{ - Array, BooleanArray, Date32Array, Decimal128Array, DictionaryArray, Float32Array, - Float64Array, GenericBinaryArray, GenericListArray, Int32Array, Int64Array, - LargeListArray, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, - StructArray, UInt32Array, UInt64Array, + Array, BinaryArray, BooleanArray, Date32Array, Decimal128Array, DictionaryArray, + Float32Array, Float64Array, GenericBinaryArray, GenericListArray, Int32Array, + Int64Array, LargeListArray, ListArray, MapArray, NullArray, OffsetSizeTrait, + PrimitiveArray, StringArray, StructArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + UInt32Array, UInt64Array, UnionArray, }, datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; @@ -127,3 +129,51 @@ pub fn as_primitive_array( ) -> Result<&PrimitiveArray, DataFusionError> { Ok(downcast_value!(array, PrimitiveArray, T)) } + +// Downcast ArrayRef to MapArray +pub fn as_map_array(array: &dyn Array) -> Result<&MapArray, DataFusionError> { + Ok(downcast_value!(array, MapArray)) +} + +// Downcast ArrayRef to NullArray +pub fn as_null_array(array: &dyn Array) -> Result<&NullArray, DataFusionError> { + Ok(downcast_value!(array, NullArray)) +} + +// Downcast ArrayRef to UnionArray +pub fn as_union_array(array: &dyn Array) -> Result<&UnionArray, DataFusionError> { + Ok(downcast_value!(array, UnionArray)) +} + +// Downcast ArrayRef to TimestampNanosecondArray +pub fn as_timestamp_nanosecond_array( + array: &dyn Array, +) -> Result<&TimestampNanosecondArray, DataFusionError> { + Ok(downcast_value!(array, TimestampNanosecondArray)) +} + +// Downcast ArrayRef to TimestampMillisecondArray +pub fn as_timestamp_millisecond_array( 
+ array: &dyn Array, +) -> Result<&TimestampMillisecondArray, DataFusionError> { + Ok(downcast_value!(array, TimestampMillisecondArray)) +} + +// Downcast ArrayRef to TimestampMicrosecondArray +pub fn as_timestamp_microsecond_array( + array: &dyn Array, +) -> Result<&TimestampMicrosecondArray, DataFusionError> { + Ok(downcast_value!(array, TimestampMicrosecondArray)) +} + +// Downcast ArrayRef to TimestampSecondArray +pub fn as_timestamp_second_array( + array: &dyn Array, +) -> Result<&TimestampSecondArray, DataFusionError> { + Ok(downcast_value!(array, TimestampSecondArray)) +} + +// Downcast ArrayRef to BinaryArray +pub fn as_binary_array(array: &dyn Array) -> Result<&BinaryArray, DataFusionError> { + Ok(downcast_value!(array, BinaryArray)) +} diff --git a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs index 30c32fc572a8e..268157e8c762d 100644 --- a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs @@ -975,9 +975,10 @@ mod test { use crate::arrow::array::Array; use crate::arrow::datatypes::{Field, TimeUnit}; use crate::avro_to_arrow::{Reader, ReaderBuilder}; - use arrow::array::TimestampMicrosecondArray; use arrow::datatypes::DataType; - use datafusion_common::cast::{as_int32_array, as_int64_array, as_list_array}; + use datafusion_common::cast::{ + as_int32_array, as_int64_array, as_list_array, as_timestamp_microsecond_array, + }; use std::fs::File; fn build_reader(name: &str, batch_size: usize) -> Reader { @@ -1008,11 +1009,8 @@ mod test { &DataType::Timestamp(TimeUnit::Microsecond, None), timestamp_col.1.data_type() ); - let timestamp_array = batch - .column(timestamp_col.0) - .as_any() - .downcast_ref::() - .unwrap(); + let timestamp_array = + as_timestamp_microsecond_array(batch.column(timestamp_col.0)).unwrap(); for i in 0..timestamp_array.len() { assert!(timestamp_array.is_valid(i)); } diff --git 
a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index c4dbf873b7178..ac1ea1ba86268 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -92,9 +92,9 @@ mod tests { use crate::datasource::file_format::test_util::scan_format; use crate::physical_plan::collect; use crate::prelude::{SessionConfig, SessionContext}; - use arrow::array::{BinaryArray, TimestampMicrosecondArray}; use datafusion_common::cast::{ - as_boolean_array, as_float32_array, as_float64_array, as_int32_array, + as_binary_array, as_boolean_array, as_float32_array, as_float64_array, + as_int32_array, as_timestamp_microsecond_array, }; use futures::StreamExt; @@ -248,11 +248,7 @@ mod tests { assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); - let array = batches[0] - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_timestamp_microsecond_array(batches[0].column(0))?; let mut values: Vec = vec![]; for i in 0..batches[0].num_rows() { values.push(array.value(i)); @@ -327,11 +323,7 @@ mod tests { assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); - let array = batches[0] - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_binary_array(batches[0].column(0))?; let mut values: Vec<&str> = vec![]; for i in 0..batches[0].num_rows() { values.push(std::str::from_utf8(array.value(i)).unwrap()); diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index a1b5307dd54ab..fa9ab13cd1151 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -585,14 +585,13 @@ mod tests { use crate::datasource::file_format::parquet::test_util::store_parquet; use crate::physical_plan::metrics::MetricValue; use crate::prelude::{SessionConfig, SessionContext}; - use 
arrow::array::{ - Array, ArrayRef, BinaryArray, StringArray, TimestampNanosecondArray, - }; + use arrow::array::{Array, ArrayRef, StringArray}; use arrow::record_batch::RecordBatch; use async_trait::async_trait; use bytes::Bytes; use datafusion_common::cast::{ - as_boolean_array, as_float32_array, as_float64_array, as_int32_array, + as_binary_array, as_boolean_array, as_float32_array, as_float64_array, + as_int32_array, as_timestamp_nanosecond_array, }; use datafusion_common::ScalarValue; use futures::stream::BoxStream; @@ -996,11 +995,7 @@ mod tests { assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); - let array = batches[0] - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_timestamp_nanosecond_array(batches[0].column(0))?; let mut values: Vec = vec![]; for i in 0..batches[0].num_rows() { values.push(array.value(i)); @@ -1075,11 +1070,7 @@ mod tests { assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); - let array = batches[0] - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_binary_array(batches[0].column(0))?; let mut values: Vec<&str> = vec![]; for i in 0..batches[0].num_rows() { values.push(std::str::from_utf8(array.value(i)).unwrap()); diff --git a/datafusion/physical-expr/src/crypto_expressions.rs b/datafusion/physical-expr/src/crypto_expressions.rs index 89422399e2e80..228e6a196c534 100644 --- a/datafusion/physical-expr/src/crypto_expressions.rs +++ b/datafusion/physical-expr/src/crypto_expressions.rs @@ -25,7 +25,7 @@ use arrow::{ }; use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; -use datafusion_common::cast::as_generic_binary_array; +use datafusion_common::cast::{as_binary_array, as_generic_binary_array}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; @@ -284,15 +284,7 @@ pub fn md5(args: &[ColumnarValue]) -> Result { // md5 requires special 
handling because of its unique utf8 return type Ok(match value { ColumnarValue::Array(array) => { - let binary_array = array - .as_ref() - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "Impossibly got non-binary array data from digest".into(), - ) - })?; + let binary_array = as_binary_array(&array)?; let string_array: StringArray = binary_array .iter() .map(|opt| opt.map(hex_encode::<_>)) diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 02d9b9e826cf2..1f54cc672b8c2 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -26,17 +26,17 @@ use arrow::{ }, }; use arrow::{ - array::{ - Date64Array, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, - }, + array::{Date64Array, TimestampNanosecondArray}, compute::kernels::temporal, datatypes::TimeUnit, temporal_conversions::timestamp_ns_to_datetime, }; use chrono::prelude::*; use chrono::Duration; -use datafusion_common::cast::as_date32_array; +use datafusion_common::cast::{ + as_date32_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, + as_timestamp_nanosecond_array, as_timestamp_second_array, +}; use datafusion_common::{DataFusionError, Result}; use datafusion_common::{ScalarType, ScalarValue}; use datafusion_expr::ColumnarValue; @@ -292,10 +292,7 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { )) } ColumnarValue::Array(array) => { - let array = array - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_timestamp_nanosecond_array(array)?; let array = array .iter() .map(f) @@ -384,10 +381,7 @@ pub fn date_bin(args: &[ColumnarValue]) -> Result { } ColumnarValue::Array(array) => match array.data_type() { DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let array = array - .as_any() - .downcast_ref::() - .unwrap() + let array = 
as_timestamp_nanosecond_array(array)? .iter() .map(f) .collect::(); @@ -423,31 +417,19 @@ macro_rules! extract_date_part { } DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => { - let array = $ARRAY - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_timestamp_second_array($ARRAY)?; Ok($FN(array)?) } TimeUnit::Millisecond => { - let array = $ARRAY - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_timestamp_millisecond_array($ARRAY)?; Ok($FN(array)?) } TimeUnit::Microsecond => { - let array = $ARRAY - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_timestamp_microsecond_array($ARRAY)?; Ok($FN(array)?) } TimeUnit::Nanosecond => { - let array = $ARRAY - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_timestamp_nanosecond_array($ARRAY)?; Ok($FN(array)?) } }, @@ -514,7 +496,10 @@ pub fn date_part(args: &[ColumnarValue]) -> Result { mod tests { use std::sync::Arc; - use arrow::array::{ArrayRef, Int64Array, IntervalDayTimeArray, StringBuilder}; + use arrow::array::{ + ArrayRef, Int64Array, IntervalDayTimeArray, StringBuilder, + TimestampMicrosecondArray, + }; use super::*; diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index 14e7072598f66..bb29709aa0463 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -17,17 +17,17 @@ use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; -use arrow::array::{ - Array, ArrayRef, Date64Array, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, -}; +use arrow::array::{Array, ArrayRef, Date64Array}; use arrow::compute::unary; use arrow::datatypes::{ DataType, Date32Type, Date64Type, Schema, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; use arrow::record_batch::RecordBatch; -use 
datafusion_common::cast::as_date32_array; +use datafusion_common::cast::{ + as_date32_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, + as_timestamp_nanosecond_array, as_timestamp_second_array, +}; use datafusion_common::scalar::{ date32_add, date64_add, microseconds_add, milliseconds_add, nanoseconds_add, seconds_add, @@ -200,20 +200,14 @@ pub fn evaluate_array( })) as ArrayRef } DataType::Timestamp(TimeUnit::Second, _) => { - let array = array - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_timestamp_second_array(&array)?; Arc::new(unary::( array, |ts_s| seconds_add(ts_s, scalar, sign).unwrap(), )) as ArrayRef } DataType::Timestamp(TimeUnit::Millisecond, _) => { - let array = array - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_timestamp_millisecond_array(&array)?; Arc::new( unary::( array, @@ -222,10 +216,7 @@ pub fn evaluate_array( ) as ArrayRef } DataType::Timestamp(TimeUnit::Microsecond, _) => { - let array = array - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_timestamp_microsecond_array(&array)?; Arc::new( unary::( array, @@ -234,10 +225,7 @@ pub fn evaluate_array( ) as ArrayRef } DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let array = array - .as_any() - .downcast_ref::() - .unwrap(); + let array = as_timestamp_nanosecond_array(&array)?; Arc::new( unary::( array, diff --git a/datafusion/row/src/writer.rs b/datafusion/row/src/writer.rs index 98f3ac4481e2d..02325c1d68c9d 100644 --- a/datafusion/row/src/writer.rs +++ b/datafusion/row/src/writer.rs @@ -22,7 +22,7 @@ use arrow::array::*; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util::{round_upto_power_of_2, set_bit_raw, unset_bit_raw}; -use datafusion_common::cast::{as_date32_array, as_string_array}; +use datafusion_common::cast::{as_binary_array, as_date32_array, as_string_array}; use datafusion_common::Result; use std::cmp::max; use std::sync::Arc; @@ -364,7 +364,7 @@ 
pub(crate) fn write_field_binary( col_idx: usize, row_idx: usize, ) { - let from = from.as_any().downcast_ref::().unwrap(); + let from = as_binary_array(from).unwrap(); let s = from.value(row_idx); let new_width = to.current_width() + s.len(); if new_width > to.data.len() {