From 07d8ce82e1900e13cd25378b424fc16f996e4bb8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Jan 2021 09:28:52 -0500 Subject: [PATCH] ARROW-11327: [Rust][DataFusion] Add DictionarySupport to create_batch_empty --- rust/datafusion/src/physical_plan/common.rs | 311 +++++++++++--------- 1 file changed, 168 insertions(+), 143 deletions(-) diff --git a/rust/datafusion/src/physical_plan/common.rs b/rust/datafusion/src/physical_plan/common.rs index 60ca857e99b..6e3f23385ff 100644 --- a/rust/datafusion/src/physical_plan/common.rs +++ b/rust/datafusion/src/physical_plan/common.rs @@ -31,8 +31,15 @@ use array::{ Time32MillisecondArray, Time32SecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; -use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; +use arrow::{ + array::PrimitiveBuilder, + datatypes::{ + ArrowPrimitiveType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, + }, + error::Result as ArrowResult, +}; use arrow::{ array::{self, ArrayRef}, datatypes::Schema, @@ -121,130 +128,138 @@ pub fn build_file_list(dir: &str, filenames: &mut Vec, ext: &str) -> Res Ok(()) } -/// creates an empty record batch. +/// Creates an empty (0 row) record batch with the specified schema pub fn create_batch_empty(schema: &Schema) -> ArrowResult { let columns = schema .fields() .iter() - .map(|f| match f.data_type() { - DataType::Float32 => { - Ok(Arc::new(Float32Array::from(vec![] as Vec)) as ArrayRef) - } - DataType::Float64 => { - Ok(Arc::new(Float64Array::from(vec![] as Vec)) as ArrayRef) - } - DataType::Int64 => { - Ok(Arc::new(Int64Array::from(vec![] as Vec)) as ArrayRef) - } - DataType::Int32 => { - Ok(Arc::new(Int32Array::from(vec![] as Vec)) as ArrayRef) - } - DataType::Int16 => { - Ok(Arc::new(Int16Array::from(vec![] as Vec)) as ArrayRef) - } - DataType::Int8 => { - Ok(Arc::new(Int8Array::from(vec![] as Vec)) as ArrayRef) - } - DataType::UInt64 => { - Ok(Arc::new(UInt64Array::from(vec![] as Vec)) as ArrayRef) - } - DataType::UInt32 => { - Ok(Arc::new(UInt32Array::from(vec![] as Vec)) as ArrayRef) - } - DataType::UInt16 => { - Ok(Arc::new(UInt16Array::from(vec![] as Vec)) as ArrayRef) - } - DataType::UInt8 => { - Ok(Arc::new(UInt8Array::from(vec![] as Vec)) as ArrayRef) - } - DataType::Utf8 => { - Ok(Arc::new(StringArray::from(vec![] as Vec<&str>)) as ArrayRef) + .map(|f| create_empty_array(f.data_type())) + .collect::>() + .map_err(DataFusionError::into_arrow_external_error)?; + + RecordBatch::try_new(Arc::new(schema.to_owned()), columns) +} + +fn create_empty_array(data_type: &DataType) -> Result { + match data_type { + DataType::Float32 => { + Ok(Arc::new(Float32Array::from(vec![] as Vec)) as ArrayRef) + } + DataType::Float64 => { + Ok(Arc::new(Float64Array::from(vec![] as Vec)) as ArrayRef) + } + DataType::Int64 => Ok(Arc::new(Int64Array::from(vec![] as Vec)) as ArrayRef), + DataType::Int32 => Ok(Arc::new(Int32Array::from(vec![] as Vec)) as ArrayRef), + DataType::Int16 => Ok(Arc::new(Int16Array::from(vec![] as Vec)) as ArrayRef), + DataType::Int8 => Ok(Arc::new(Int8Array::from(vec![] as Vec)) as ArrayRef), + DataType::UInt64 => { + Ok(Arc::new(UInt64Array::from(vec![] as Vec)) as ArrayRef) + } + DataType::UInt32 => { + Ok(Arc::new(UInt32Array::from(vec![] as Vec)) as ArrayRef) + } + DataType::UInt16 => { + Ok(Arc::new(UInt16Array::from(vec![] as Vec)) as ArrayRef) + } + DataType::UInt8 => Ok(Arc::new(UInt8Array::from(vec![] as Vec)) as ArrayRef), + DataType::Utf8 => { + Ok(Arc::new(StringArray::from(vec![] as Vec<&str>)) as ArrayRef) + } + DataType::LargeUtf8 => { + Ok(Arc::new(LargeStringArray::from(vec![] as Vec<&str>)) as ArrayRef) + } + DataType::Boolean => { + Ok(Arc::new(BooleanArray::from(vec![] as Vec)) as ArrayRef) + } + DataType::Decimal(scale, precision) => { + let array_data = ArrayData::builder(DataType::Decimal(*scale, *precision)) + .len(0) + .add_buffer(Buffer::from(&[])) + .build(); + Ok(Arc::new(DecimalArray::from(array_data)) as ArrayRef) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => Ok(Arc::new( + TimestampNanosecondArray::from_vec(vec![] as Vec, tz.clone()), + ) as ArrayRef), + DataType::Timestamp(TimeUnit::Microsecond, tz) => Ok(Arc::new( + TimestampMicrosecondArray::from_vec(vec![] as Vec, tz.clone()), + ) as ArrayRef), + DataType::Timestamp(TimeUnit::Millisecond, tz) => Ok(Arc::new( + TimestampMillisecondArray::from_vec(vec![] as Vec, tz.clone()), + ) as ArrayRef), + DataType::Timestamp(TimeUnit::Second, tz) => Ok(Arc::new( + TimestampSecondArray::from_vec(vec![] as Vec, tz.clone()), + ) as ArrayRef), + DataType::Date32(_) => { + Ok(Arc::new(Date32Array::from(vec![] as Vec)) as ArrayRef) + } + DataType::Date64(_) => { + Ok(Arc::new(Date64Array::from(vec![] as Vec)) as ArrayRef) + } + DataType::Time32(unit) => match unit { + TimeUnit::Second => { + Ok(Arc::new(Time32SecondArray::from(vec![] as Vec)) as ArrayRef) } - DataType::LargeUtf8 => { - Ok(Arc::new(LargeStringArray::from(vec![] as Vec<&str>)) as ArrayRef) + TimeUnit::Millisecond => { + Ok(Arc::new(Time32MillisecondArray::from(vec![] as Vec)) + as ArrayRef) } - DataType::Boolean => { - Ok(Arc::new(BooleanArray::from(vec![] as Vec)) as ArrayRef) + TimeUnit::Microsecond | TimeUnit::Nanosecond => { + Err(DataFusionError::NotImplemented(format!( + "Cannot convert datatype {:?} to array", + data_type + ))) } - DataType::Decimal(scale, precision) => { - let array_data = - ArrayData::builder(DataType::Decimal(*scale, *precision)) - .len(0) - .add_buffer(Buffer::from(&[])) - .build(); - - Ok(Arc::new(DecimalArray::from(array_data)) as ArrayRef) + }, + DataType::Time64(unit) => match unit { + TimeUnit::Second | TimeUnit::Millisecond => { + Err(DataFusionError::NotImplemented(format!( + "Cannot convert datatype {:?} to array", + data_type + ))) } - DataType::Timestamp(TimeUnit::Nanosecond, tz) => Ok(Arc::new( - TimestampNanosecondArray::from_vec(vec![] as Vec, tz.clone()), - ) - as ArrayRef), - DataType::Timestamp(TimeUnit::Microsecond, tz) => Ok(Arc::new( - TimestampMicrosecondArray::from_vec(vec![] as Vec, tz.clone()), - ) - as ArrayRef), - DataType::Timestamp(TimeUnit::Millisecond, tz) => Ok(Arc::new( - TimestampMillisecondArray::from_vec(vec![] as Vec, tz.clone()), - ) - as ArrayRef), - DataType::Timestamp(TimeUnit::Second, tz) => Ok(Arc::new( - TimestampSecondArray::from_vec(vec![] as Vec, tz.clone()), - ) as ArrayRef), - DataType::Date32(_) => { - Ok(Arc::new(Date32Array::from(vec![] as Vec)) as ArrayRef) + TimeUnit::Microsecond => { + Ok(Arc::new(Time64MicrosecondArray::from(vec![] as Vec)) + as ArrayRef) } - DataType::Date64(_) => { - Ok(Arc::new(Date64Array::from(vec![] as Vec)) as ArrayRef) + TimeUnit::Nanosecond => { + Ok(Arc::new(Time64NanosecondArray::from(vec![] as Vec)) as ArrayRef) } - DataType::Time32(unit) => match unit { - TimeUnit::Second => { - Ok(Arc::new(Time32SecondArray::from(vec![] as Vec)) as ArrayRef) - } - TimeUnit::Millisecond => { - Ok(Arc::new(Time32MillisecondArray::from(vec![] as Vec)) - as ArrayRef) - } - TimeUnit::Microsecond | TimeUnit::Nanosecond => { - Err(DataFusionError::NotImplemented(format!( - "Cannot convert datatype {:?} to array", - f.data_type() - ))) - } - }, - DataType::Time64(unit) => match unit { - TimeUnit::Second | TimeUnit::Millisecond => { - Err(DataFusionError::NotImplemented(format!( - "Cannot convert datatype {:?} to array", - f.data_type() - ))) - } - TimeUnit::Microsecond => { - Ok(Arc::new(Time64MicrosecondArray::from(vec![] as Vec)) - as ArrayRef) - } - TimeUnit::Nanosecond => { - Ok(Arc::new(Time64NanosecondArray::from(vec![] as Vec)) - as ArrayRef) - } - }, - DataType::List(nested_type) => Ok(build_empty_list_array::( - nested_type.data_type().clone(), - )?), - DataType::LargeList(nested_type) => Ok(build_empty_list_array::( - nested_type.data_type().clone(), - )?), - DataType::FixedSizeList(nested_type, _) => Ok( - build_empty_fixed_size_list_array(nested_type.data_type().clone())?, - ), - _ => Err(DataFusionError::NotImplemented(format!( - "Cannot convert datatype {:?} to array", - f.data_type() - ))), - }) - .collect::>() - .map_err(DataFusionError::into_arrow_external_error)?; + }, + DataType::List(nested_type) => Ok(build_empty_list_array::( + nested_type.data_type().clone(), + )?), + DataType::LargeList(nested_type) => Ok(build_empty_list_array::( + nested_type.data_type().clone(), + )?), + DataType::FixedSizeList(nested_type, _) => Ok(build_empty_fixed_size_list_array( + nested_type.data_type().clone(), + )?), + DataType::Dictionary(key_type, value_type) => match key_type.as_ref() { + DataType::UInt8 => build_empty_dictionary::(value_type), + DataType::UInt16 => build_empty_dictionary::(value_type), + DataType::UInt32 => build_empty_dictionary::(value_type), + DataType::UInt64 => build_empty_dictionary::(value_type), + DataType::Int8 => build_empty_dictionary::(value_type), + DataType::Int16 => build_empty_dictionary::(value_type), + DataType::Int32 => build_empty_dictionary::(value_type), + DataType::Int64 => build_empty_dictionary::(value_type), + _ => unreachable!(), + }, + _ => Err(DataFusionError::NotImplemented(format!( + "Creating empty array for type {:?} is not yet implemented", + data_type + ))), + } +} - RecordBatch::try_new(Arc::new(schema.to_owned()), columns) +fn build_empty_dictionary( + value_type: &DataType, +) -> Result { + let values: ArrayRef = create_empty_array(value_type)?; + let mut keys_builder: PrimitiveBuilder = PrimitiveBuilder::new(0); + let dict_array = keys_builder.finish_dict(values); + Ok(Arc::new(dict_array)) } #[cfg(test)] @@ -254,41 +269,51 @@ mod tests { #[test] fn test_create_batch_empty() { + use DataType::*; + let schema = Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::UInt32, false), - Field::new("c3", DataType::Int8, false), - Field::new("c4", DataType::Int16, false), - Field::new("c5", DataType::Int32, false), - Field::new("c6", DataType::Int64, false), - Field::new("c7", DataType::UInt8, false), - Field::new("c8", DataType::UInt16, false), - Field::new("c9", DataType::UInt32, false), - Field::new("c10", DataType::UInt64, false), - Field::new("c11", DataType::Float32, false), - Field::new("c12", DataType::Float64, false), - Field::new("c13", DataType::Utf8, false), - Field::new("c14", DataType::Decimal(10, 10), false), - Field::new("c15", DataType::Timestamp(TimeUnit::Second, None), false), + Field::new("c1", Utf8, false), + Field::new("c2", UInt32, false), + Field::new("c3", Int8, false), + Field::new("c4", Int16, false), + Field::new("c5", Int32, false), + Field::new("c6", Int64, false), + Field::new("c7", UInt8, false), + Field::new("c8", UInt16, false), + Field::new("c9", UInt32, false), + Field::new("c10", UInt64, false), + Field::new("c11", Float32, false), + Field::new("c12", Float64, false), + Field::new("c13", Utf8, false), + Field::new("c14", Decimal(10, 10), false), + Field::new("c15", Timestamp(TimeUnit::Second, None), false), + Field::new("c16", Timestamp(TimeUnit::Microsecond, None), false), + Field::new("c17", Timestamp(TimeUnit::Millisecond, None), false), + Field::new("c18", Timestamp(TimeUnit::Nanosecond, None), false), + Field::new("c19", Boolean, false), + Field::new("20", Dictionary(Box::new(UInt8), Box::new(Utf8)), false), + Field::new("21", Dictionary(Box::new(UInt16), Box::new(Utf8)), false), + Field::new("22", Dictionary(Box::new(UInt32), Box::new(Utf8)), false), + Field::new("23", Dictionary(Box::new(UInt64), Box::new(Utf8)), false), + Field::new("24", Dictionary(Box::new(Int8), Box::new(Utf8)), false), + Field::new("25", Dictionary(Box::new(Int16), Box::new(Utf8)), false), + Field::new("26", Dictionary(Box::new(Int32), Box::new(Utf8)), false), + Field::new("27", Dictionary(Box::new(Int64), Box::new(Utf8)), false), + // try non string dictionary + Field::new("28", Dictionary(Box::new(UInt8), Box::new(Int64)), false), Field::new( - "c16", - DataType::Timestamp(TimeUnit::Microsecond, None), + "29", + Dictionary(Box::new(UInt8), Box::new(LargeUtf8)), false, ), - Field::new( - "c17", - DataType::Timestamp(TimeUnit::Millisecond, None), - false, - ), - Field::new( - "c18", - DataType::Timestamp(TimeUnit::Nanosecond, None), - false, - ), - Field::new("c19", DataType::Boolean, false), ]); let batch = create_batch_empty(&schema).unwrap(); - assert_eq!(batch.columns().len(), 19); + assert_eq!(batch.columns().len(), 29); + assert_eq!(batch.num_rows(), 0); + + for (i, array) in batch.columns().iter().enumerate() { + assert_eq!(array.len(), 0, "Array[{}] was zero length", i); + } } }