diff --git a/cpp/submodules/parquet-testing b/cpp/submodules/parquet-testing index 46c9e977f58..40379b3c582 160000 --- a/cpp/submodules/parquet-testing +++ b/cpp/submodules/parquet-testing @@ -1 +1 @@ -Subproject commit 46c9e977f58f6c5ef1b81f782f3746b3656e5a8c +Subproject commit 40379b3c58298fd22589dec7e41748375b5a8e82 diff --git a/rust/arrow/src/ipc/convert.rs b/rust/arrow/src/ipc/convert.rs index 84d6b39dba9..5a4db8f2624 100644 --- a/rust/arrow/src/ipc/convert.rs +++ b/rust/arrow/src/ipc/convert.rs @@ -518,11 +518,12 @@ pub(crate) fn get_fb_field_type<'a: 'b, 'b>( ) } List(ref list_type) => { + let field_name = fbb.create_string("list"); // field schema requires name to be not None let inner_types = get_fb_field_type(list_type, fbb); let child = ipc::Field::create( fbb, &ipc::FieldArgs { - name: None, + name: Some(field_name), nullable: false, type_type: inner_types.0, type_: Some(inner_types.1), diff --git a/rust/arrow/src/record_batch.rs b/rust/arrow/src/record_batch.rs index 0eac961b14f..0109f9e0f3a 100644 --- a/rust/arrow/src/record_batch.rs +++ b/rust/arrow/src/record_batch.rs @@ -300,6 +300,25 @@ mod tests { .len(4) .add_buffer(Buffer::from([42, 28, 19, 31].to_byte_slice())) .build(); + + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(9) + .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7, 8].to_byte_slice())) + .build(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7], [8]] + let value_offsets = Buffer::from(&[0, 3, 6, 8, 9].to_byte_slice()); + + // Construct a list array from the above two + let list_data_type = DataType::List(Box::new(DataType::Int32)); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(4) + .add_buffer(value_offsets.clone()) + .add_child_data(value_data.clone()) + .build(); + let struct_array = StructArray::from(vec![ ( Field::new("b", DataType::Boolean, false), @@ -310,10 +329,14 @@ mod tests { Field::new("c", 
DataType::Int32, false), Arc::new(Int32Array::from(vec![42, 28, 19, 31])), ), + ( + Field::new("d", DataType::List(Box::new(DataType::Int32)), false), + Arc::new(ListArray::from(list_data.clone())), + ), ]); let batch = RecordBatch::from(&struct_array); - assert_eq!(2, batch.num_columns()); + assert_eq!(3, batch.num_columns()); assert_eq!(4, batch.num_rows()); assert_eq!( struct_array.data_type(), @@ -321,5 +344,96 @@ mod tests { ); assert_eq!(batch.column(0).data(), boolean_data); assert_eq!(batch.column(1).data(), int_data); + assert_eq!(batch.column(2).data(), list_data); + } + + #[test] + fn create_record_batch_with_list_column() { + let schema = Schema::new(vec![Field::new( + "a", + DataType::List(Box::new(DataType::Int32)), + false, + )]); + + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice())) + .build(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice()); + + // Construct a list array from the above two + let list_data_type = DataType::List(Box::new(DataType::Int32)); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .add_buffer(value_offsets.clone()) + .add_child_data(value_data.clone()) + .build(); + let a = ListArray::from(list_data); + + let record_batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap(); + + assert_eq!(3, record_batch.num_rows()); + assert_eq!(1, record_batch.num_columns()); + assert_eq!( + &DataType::List(Box::new(DataType::Int32)), + record_batch.schema().field(0).data_type() + ); + assert_eq!(3, record_batch.column(0).data().len()); + } + + #[test] + fn create_record_batch_with_list_column_nulls() { + let schema = Schema::new(vec![Field::new( + "a", + DataType::List(Box::new(DataType::Int32)), + false, + )]); + + let values_builder = PrimitiveBuilder::::new(10); + 
let mut builder = ListBuilder::new(values_builder); + + builder.values().append_null().unwrap(); + builder.values().append_null().unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + builder.append(true).unwrap(); + + // [[null, null], null, []] + let list_array = builder.finish(); + + let record_batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_array)]).unwrap(); + + assert_eq!(3, record_batch.num_rows()); + assert_eq!(1, record_batch.num_columns()); + assert_eq!( + &DataType::List(Box::new(DataType::Int32)), + record_batch.schema().field(0).data_type() + ); + assert_eq!(3, record_batch.column(0).data().len()); + + assert_eq!(false, record_batch.column(0).is_null(0)); + assert_eq!(true, record_batch.column(0).is_null(1)); + assert_eq!(false, record_batch.column(0).is_null(2)); + + let col_as_list_array = record_batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(2, col_as_list_array.value(0).len()); + assert_eq!(0, col_as_list_array.value(2).len()); + + let sublist_0_val = col_as_list_array.value(0); + let sublist_0 = sublist_0_val.as_any().downcast_ref::().unwrap(); + + assert_eq!(true, sublist_0.is_null(0)); + assert_eq!(true, sublist_0.is_null(1)); } } diff --git a/rust/arrow/src/util/pretty.rs b/rust/arrow/src/util/pretty.rs index 4d3c64408a7..f7d784e4289 100644 --- a/rust/arrow/src/util/pretty.rs +++ b/rust/arrow/src/util/pretty.rs @@ -79,6 +79,22 @@ macro_rules! make_string { }}; } +macro_rules! make_string_from_list { + ($column: ident, $row: ident) => {{ + let list = $column + .as_any() + .downcast_ref::() + .ok_or(ArrowError::InvalidArgumentError(format!( + "Repl error: could not convert list column to list array." + )))? 
+ .value($row); + let string_values = (0..list.len()) + .map(|i| array_value_to_string(list.clone(), i)) + .collect::>>()?; + Ok(format!("[{}]", string_values.join(", "))) + }}; +} + /// Get the value at the given row in an array as a string fn array_value_to_string(column: array::ArrayRef, row: usize) -> Result { match column.data_type() { @@ -125,6 +141,7 @@ fn array_value_to_string(column: array::ArrayRef, row: usize) -> Result DataType::Time64(unit) if *unit == TimeUnit::Nanosecond => { make_string!(array::Time64NanosecondArray, column, row) } + DataType::List(_) => make_string_from_list!(column, row), _ => Err(ArrowError::InvalidArgumentError(format!( "Unsupported {:?} type for repl.", column.data_type() diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index 9969eeddf10..08ce7e62695 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -39,9 +39,9 @@ use arrow::compute; use arrow::compute::kernels::arithmetic::{add, divide, multiply, subtract}; use arrow::compute::kernels::boolean::{and, or}; use arrow::compute::kernels::cast::cast; -use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow::compute::kernels::comparison::{ - eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, lt_eq_utf8, lt_utf8, neq_utf8, nlike_utf8, + eq, eq_utf8, gt, gt_eq, gt_eq_utf8, gt_utf8, like_utf8, lt, lt_eq, lt_eq_utf8, + lt_utf8, neq, neq_utf8, nlike_utf8, }; use arrow::datatypes::{DataType, Schema, TimeUnit}; use arrow::record_batch::RecordBatch; diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 704c5482543..0170cf9969b 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -1006,8 +1006,8 @@ mod tests { .build()?; let expected = "Projection: #state, #total_salary\ - \n Aggregate: groupBy=[[#state]], aggr=[[SUM(#salary) AS 
total_salary]]\ - \n TableScan: employee.csv projection=Some([3, 4])"; + \n Aggregate: groupBy=[[#state]], aggr=[[SUM(#salary) AS total_salary]]\ + \n TableScan: employee.csv projection=Some([3, 4])"; assert_eq!(expected, format!("{:?}", plan)); diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs index 10e60adf787..488484f0150 100644 --- a/rust/datafusion/src/optimizer/projection_push_down.rs +++ b/rust/datafusion/src/optimizer/projection_push_down.rs @@ -332,7 +332,7 @@ mod tests { .build()?; let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#0)]]\ - \n TableScan: test projection=Some([1])"; + \n TableScan: test projection=Some([1])"; assert_optimized_plan_eq(&plan, expected); @@ -348,7 +348,7 @@ mod tests { .build()?; let expected = "Aggregate: groupBy=[[#1]], aggr=[[MAX(#0)]]\ - \n TableScan: test projection=Some([1, 2])"; + \n TableScan: test projection=Some([1, 2])"; assert_optimized_plan_eq(&plan, expected); @@ -365,8 +365,8 @@ mod tests { .build()?; let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#0)]]\ - \n Selection: #1\ - \n TableScan: test projection=Some([1, 2])"; + \n Selection: #1\ + \n TableScan: test projection=Some([1, 2])"; assert_optimized_plan_eq(&plan, expected); @@ -385,7 +385,7 @@ mod tests { .build()?; let expected = "Projection: CAST(#0 AS Float64)\ - \n TableScan: test projection=Some([2])"; + \n TableScan: test projection=Some([2])"; assert_optimized_plan_eq(&projection, expected); @@ -405,7 +405,7 @@ mod tests { assert_fields_eq(&plan, vec!["a", "b"]); let expected = "Projection: #0, #1\ - \n TableScan: test projection=Some([0, 1])"; + \n TableScan: test projection=Some([0, 1])"; assert_optimized_plan_eq(&plan, expected); @@ -426,8 +426,8 @@ mod tests { assert_fields_eq(&plan, vec!["c", "a"]); let expected = "Limit: UInt32(5)\ - \n Projection: #1, #0\ - \n TableScan: test projection=Some([0, 2])"; + \n Projection: #1, #0\ + \n TableScan: test projection=Some([0, 
2])"; assert_optimized_plan_eq(&plan, expected); diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index a51aa41e02a..ec552c36ed8 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -579,7 +579,7 @@ mod tests { quick_test( "SELECT * from person", "Projection: #0, #1, #2, #3, #4, #5, #6\ - \n TableScan: person projection=None", + \n TableScan: person projection=None", ); } diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 44ddca2713a..21f6ef8a67b 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::convert::TryFrom; use std::env; use std::sync::Arc; @@ -22,7 +23,7 @@ extern crate arrow; extern crate datafusion; use arrow::array::*; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, Int64Type, Schema}; use arrow::record_batch::RecordBatch; use datafusion::datasource::csv::CsvReadOptions; @@ -119,6 +120,100 @@ fn parquet_single_nan_schema() { } } +#[test] +fn parquet_list_columns() { + let mut ctx = ExecutionContext::new(); + let testdata = env::var("PARQUET_TEST_DATA").expect("PARQUET_TEST_DATA not defined"); + ctx.register_parquet( + "list_columns", + &format!("{}/list_columns.parquet", testdata), + ) + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "int64_list", + DataType::List(Box::new(DataType::Int64)), + true, + ), + Field::new("utf8_list", DataType::List(Box::new(DataType::Utf8)), true), + ])); + + let sql = "SELECT int64_list, utf8_list FROM list_columns"; + let plan = ctx.create_logical_plan(&sql).unwrap(); + let plan = ctx.optimize(&plan).unwrap(); + let plan = ctx.create_physical_plan(&plan, DEFAULT_BATCH_SIZE).unwrap(); + let results = ctx.collect(plan.as_ref()).unwrap(); + + // int64_list utf8_list + // 0 [1, 2, 3] [abc, efg, hij] + // 1 [None, 1] None + // 2 
[4] [efg, None, hij, xyz] + + assert_eq!(1, results.len()); + let batch = &results[0]; + assert_eq!(3, batch.num_rows()); + assert_eq!(2, batch.num_columns()); + assert_eq!(&schema, batch.schema()); + + let int_list_array = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let utf8_list_array = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + int_list_array + .value(0) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(1), Some(2), Some(3),]) + ); + + assert_eq!( + utf8_list_array + .value(0) + .as_any() + .downcast_ref::() + .unwrap(), + &StringArray::try_from(vec![Some("abc"), Some("efg"), Some("hij"),]).unwrap() + ); + + assert_eq!( + int_list_array + .value(1) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![None, Some(1),]) + ); + + assert!(utf8_list_array.is_null(1)); + + assert_eq!( + int_list_array + .value(2) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(4),]) + ); + + let result = utf8_list_array.value(2); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.value(0), "efg"); + assert!(result.is_null(1)); + assert_eq!(result.value(2), "hij"); + assert_eq!(result.value(3), "xyz"); +} + #[test] fn csv_count_star() -> Result<()> { let mut ctx = ExecutionContext::new(); diff --git a/rust/parquet/src/arrow/array_reader.rs b/rust/parquet/src/arrow/array_reader.rs index 3601166f564..a0fa13594c0 100644 --- a/rust/parquet/src/arrow/array_reader.rs +++ b/rust/parquet/src/arrow/array_reader.rs @@ -25,11 +25,12 @@ use std::sync::Arc; use std::vec::Vec; use arrow::array::{ - ArrayDataBuilder, ArrayDataRef, ArrayRef, BooleanBufferBuilder, BufferBuilderTrait, - Int16BufferBuilder, StructArray, + Array, ArrayData, ArrayDataBuilder, ArrayDataRef, ArrayRef, BinaryArray, + BinaryBuilder, BooleanBufferBuilder, BufferBuilderTrait, FixedSizeBinaryArray, + FixedSizeBinaryBuilder, Int16BufferBuilder, ListArray, 
ListBuilder, PrimitiveArray, + PrimitiveBuilder, StringArray, StringBuilder, StructArray, }; use arrow::buffer::{Buffer, MutableBuffer}; -use arrow::datatypes::{DataType as ArrowType, Field, IntervalUnit, TimeUnit}; use crate::arrow::converter::{ BinaryArrayConverter, BinaryConverter, BoolConverter, BooleanArrayConverter, @@ -54,6 +55,29 @@ use crate::schema::types::{ ColumnDescPtr, ColumnDescriptor, ColumnPath, SchemaDescPtr, Type, TypePtr, }; use crate::schema::visitor::TypeVisitor; +use arrow::datatypes::{ + BooleanType as ArrowBooleanType, DataType as ArrowType, + Date32Type as ArrowDate32Type, Date64Type as ArrowDate64Type, + DurationMicrosecondType as ArrowDurationMicrosecondType, + DurationMillisecondType as ArrowDurationMillisecondType, + DurationNanosecondType as ArrowDurationNanosecondType, + DurationSecondType as ArrowDurationSecondType, Field, + Float32Type as ArrowFloat32Type, Float64Type as ArrowFloat64Type, + Int16Type as ArrowInt16Type, Int32Type as ArrowInt32Type, + Int64Type as ArrowInt64Type, Int8Type as ArrowInt8Type, IntervalUnit, + Time32MillisecondType as ArrowTime32MillisecondType, + Time32SecondType as ArrowTime32SecondType, + Time64MicrosecondType as ArrowTime64MicrosecondType, + Time64NanosecondType as ArrowTime64NanosecondType, TimeUnit, + TimeUnit as ArrowTimeUnit, TimestampMicrosecondType as ArrowTimestampMicrosecondType, + TimestampMillisecondType as ArrowTimestampMillisecondType, + TimestampNanosecondType as ArrowTimestampNanosecondType, + TimestampSecondType as ArrowTimestampSecondType, ToByteSlice, + UInt16Type as ArrowUInt16Type, UInt32Type as ArrowUInt32Type, + UInt64Type as ArrowUInt64Type, UInt8Type as ArrowUInt8Type, +}; + +use arrow::util::bit_util; use std::any::Any; /// Array reader reads parquet data into arrow array. @@ -426,6 +450,398 @@ where } } +/// Implementation of list array reader. 
+pub struct ListArrayReader { + item_reader: Box, + data_type: ArrowType, + item_type: ArrowType, + list_def_level: i16, + list_rep_level: i16, + def_level_buffer: Option, + rep_level_buffer: Option, +} + +impl ListArrayReader { + /// Construct list array reader. + pub fn new( + item_reader: Box, + data_type: ArrowType, + item_type: ArrowType, + def_level: i16, + rep_level: i16, + ) -> Self { + Self { + item_reader, + data_type, + item_type, + list_def_level: def_level, + list_rep_level: rep_level, + def_level_buffer: None, + rep_level_buffer: None, + } + } +} + +macro_rules! build_empty_list_array_with_primitive_items { + ($item_type:ident) => {{ + let values_builder = PrimitiveBuilder::<$item_type>::new(0); + let mut builder = ListBuilder::new(values_builder); + let empty_list_array = builder.finish(); + Ok(Arc::new(empty_list_array)) + }}; +} + +macro_rules! build_empty_list_array_with_non_primitive_items { + ($builder:ident) => {{ + let values_builder = $builder::new(0); + let mut builder = ListBuilder::new(values_builder); + let empty_list_array = builder.finish(); + Ok(Arc::new(empty_list_array)) + }}; +} + +fn build_empty_list_array(item_type: ArrowType) -> Result { + match item_type { + ArrowType::UInt8 => build_empty_list_array_with_primitive_items!(ArrowUInt8Type), + ArrowType::UInt16 => { + build_empty_list_array_with_primitive_items!(ArrowUInt16Type) + } + ArrowType::UInt32 => { + build_empty_list_array_with_primitive_items!(ArrowUInt32Type) + } + ArrowType::UInt64 => { + build_empty_list_array_with_primitive_items!(ArrowUInt64Type) + } + ArrowType::Int8 => build_empty_list_array_with_primitive_items!(ArrowInt8Type), + ArrowType::Int16 => build_empty_list_array_with_primitive_items!(ArrowInt16Type), + ArrowType::Int32 => build_empty_list_array_with_primitive_items!(ArrowInt32Type), + ArrowType::Int64 => build_empty_list_array_with_primitive_items!(ArrowInt64Type), + ArrowType::Float32 => { + build_empty_list_array_with_primitive_items!(ArrowFloat32Type) 
+ } + ArrowType::Float64 => { + build_empty_list_array_with_primitive_items!(ArrowFloat64Type) + } + ArrowType::Boolean => { + build_empty_list_array_with_primitive_items!(ArrowBooleanType) + } + ArrowType::Date32(_) => { + build_empty_list_array_with_primitive_items!(ArrowDate32Type) + } + ArrowType::Date64(_) => { + build_empty_list_array_with_primitive_items!(ArrowDate64Type) + } + ArrowType::Time32(ArrowTimeUnit::Second) => { + build_empty_list_array_with_primitive_items!(ArrowTime32SecondType) + } + ArrowType::Time32(ArrowTimeUnit::Millisecond) => { + build_empty_list_array_with_primitive_items!(ArrowTime32MillisecondType) + } + ArrowType::Time64(ArrowTimeUnit::Microsecond) => { + build_empty_list_array_with_primitive_items!(ArrowTime64MicrosecondType) + } + ArrowType::Time64(ArrowTimeUnit::Nanosecond) => { + build_empty_list_array_with_primitive_items!(ArrowTime64NanosecondType) + } + ArrowType::Duration(ArrowTimeUnit::Second) => { + build_empty_list_array_with_primitive_items!(ArrowDurationSecondType) + } + ArrowType::Duration(ArrowTimeUnit::Millisecond) => { + build_empty_list_array_with_primitive_items!(ArrowDurationMillisecondType) + } + ArrowType::Duration(ArrowTimeUnit::Microsecond) => { + build_empty_list_array_with_primitive_items!(ArrowDurationMicrosecondType) + } + ArrowType::Duration(ArrowTimeUnit::Nanosecond) => { + build_empty_list_array_with_primitive_items!(ArrowDurationNanosecondType) + } + ArrowType::Timestamp(ArrowTimeUnit::Second, _) => { + build_empty_list_array_with_primitive_items!(ArrowTimestampSecondType) + } + ArrowType::Timestamp(ArrowTimeUnit::Millisecond, _) => { + build_empty_list_array_with_primitive_items!(ArrowTimestampMillisecondType) + } + ArrowType::Timestamp(ArrowTimeUnit::Microsecond, _) => { + build_empty_list_array_with_primitive_items!(ArrowTimestampMicrosecondType) + } + ArrowType::Timestamp(ArrowTimeUnit::Nanosecond, _) => { + build_empty_list_array_with_primitive_items!(ArrowTimestampNanosecondType) + } + 
ArrowType::Utf8 => { + build_empty_list_array_with_non_primitive_items!(StringBuilder) + } + ArrowType::Binary => { + build_empty_list_array_with_non_primitive_items!(BinaryBuilder) + } + _ => Err(ParquetError::General(format!( + "ListArray of type List({:?}) is not supported by array_reader", + item_type + ))), + } +} + +macro_rules! remove_primitive_array_indices { + ($arr: expr, $item_type:ty, $indices:expr) => {{ + let array_data = match $arr.as_any().downcast_ref::>() { + Some(a) => a, + _ => return Err(ParquetError::General(format!("Error generating next batch for ListArray: {:?} cannot be downcast to PrimitiveArray", $arr))), + }; + let mut builder = PrimitiveBuilder::<$item_type>::new($arr.len()); + for i in 0..array_data.len() { + if !$indices.contains(&i) { + if array_data.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array_data.value(i))?; + } + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +macro_rules! remove_array_indices_custom_builder { + ($arr: expr, $array_type:ty, $item_builder:ident, $indices:expr) => {{ + let array_data = match $arr.as_any().downcast_ref::<$array_type>() { + Some(a) => a, + _ => return Err(ParquetError::General(format!("Error generating next batch for ListArray: {:?} cannot be downcast to PrimitiveArray", $arr))), + }; + let mut builder = $item_builder::new(array_data.len()); + + for i in 0..array_data.len() { + if !$indices.contains(&i) { + if array_data.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array_data.value(i))?; + } + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +macro_rules! 
remove_fixed_size_binary_array_indices { + ($arr: expr, $array_type:ty, $item_builder:ident, $indices:expr, $len:expr) => {{ + let array_data = match $arr.as_any().downcast_ref::<$array_type>() { + Some(a) => a, + _ => return Err(ParquetError::General(format!("Error generating next batch for ListArray: {:?} cannot be downcast to PrimitiveArray", $arr))), + }; + let mut builder = FixedSizeBinaryBuilder::new(array_data.len(), $len); + for i in 0..array_data.len() { + if !$indices.contains(&i) { + if array_data.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array_data.value(i))?; + } + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +fn remove_indices( + arr: ArrayRef, + item_type: ArrowType, + indices: Vec, +) -> Result { + match item_type { + ArrowType::UInt8 => remove_primitive_array_indices!(arr, ArrowUInt8Type, indices), + ArrowType::UInt16 => { + remove_primitive_array_indices!(arr, ArrowUInt16Type, indices) + } + ArrowType::UInt32 => { + remove_primitive_array_indices!(arr, ArrowUInt32Type, indices) + } + ArrowType::UInt64 => { + remove_primitive_array_indices!(arr, ArrowUInt64Type, indices) + } + ArrowType::Int8 => remove_primitive_array_indices!(arr, ArrowInt8Type, indices), + ArrowType::Int16 => remove_primitive_array_indices!(arr, ArrowInt16Type, indices), + ArrowType::Int32 => remove_primitive_array_indices!(arr, ArrowInt32Type, indices), + ArrowType::Int64 => remove_primitive_array_indices!(arr, ArrowInt64Type, indices), + ArrowType::Float32 => { + remove_primitive_array_indices!(arr, ArrowFloat32Type, indices) + } + ArrowType::Float64 => { + remove_primitive_array_indices!(arr, ArrowFloat64Type, indices) + } + ArrowType::Boolean => { + remove_primitive_array_indices!(arr, ArrowBooleanType, indices) + } + ArrowType::Date32(_) => { + remove_primitive_array_indices!(arr, ArrowDate32Type, indices) + } + ArrowType::Date64(_) => { + remove_primitive_array_indices!(arr, ArrowDate64Type, indices) + } + 
ArrowType::Time32(ArrowTimeUnit::Second) => { + remove_primitive_array_indices!(arr, ArrowTime32SecondType, indices) + } + ArrowType::Time32(ArrowTimeUnit::Millisecond) => { + remove_primitive_array_indices!(arr, ArrowTime32MillisecondType, indices) + } + ArrowType::Time64(ArrowTimeUnit::Microsecond) => { + remove_primitive_array_indices!(arr, ArrowTime64MicrosecondType, indices) + } + ArrowType::Time64(ArrowTimeUnit::Nanosecond) => { + remove_primitive_array_indices!(arr, ArrowTime64NanosecondType, indices) + } + ArrowType::Duration(ArrowTimeUnit::Second) => { + remove_primitive_array_indices!(arr, ArrowDurationSecondType, indices) + } + ArrowType::Duration(ArrowTimeUnit::Millisecond) => { + remove_primitive_array_indices!(arr, ArrowDurationMillisecondType, indices) + } + ArrowType::Duration(ArrowTimeUnit::Microsecond) => { + remove_primitive_array_indices!(arr, ArrowDurationMicrosecondType, indices) + } + ArrowType::Duration(ArrowTimeUnit::Nanosecond) => { + remove_primitive_array_indices!(arr, ArrowDurationNanosecondType, indices) + } + ArrowType::Timestamp(ArrowTimeUnit::Second, _) => { + remove_primitive_array_indices!(arr, ArrowTimestampSecondType, indices) + } + ArrowType::Timestamp(ArrowTimeUnit::Millisecond, _) => { + remove_primitive_array_indices!(arr, ArrowTimestampMillisecondType, indices) + } + ArrowType::Timestamp(ArrowTimeUnit::Microsecond, _) => { + remove_primitive_array_indices!(arr, ArrowTimestampMicrosecondType, indices) + } + ArrowType::Timestamp(ArrowTimeUnit::Nanosecond, _) => { + remove_primitive_array_indices!(arr, ArrowTimestampNanosecondType, indices) + } + ArrowType::Utf8 => { + remove_array_indices_custom_builder!(arr, StringArray, StringBuilder, indices) + } + ArrowType::Binary => { + remove_array_indices_custom_builder!(arr, BinaryArray, BinaryBuilder, indices) + } + ArrowType::FixedSizeBinary(size) => remove_fixed_size_binary_array_indices!( + arr, + FixedSizeBinaryArray, + FixedSizeBinaryBuilder, + indices, + size + ), + _ => 
Err(ParquetError::General(format!( + "ListArray of type List({:?}) is not supported by array_reader", + item_type + ))), + } +} + +/// Implementation of ListArrayReader. Nested lists and lists of structs are not yet supported. +impl ArrayReader for ListArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns data type. + /// This must be a List. + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + fn next_batch(&mut self, batch_size: usize) -> Result { + let next_batch_array = self.item_reader.next_batch(batch_size)?; + let item_type = self.item_reader.get_data_type().clone(); + + if next_batch_array.len() == 0 { + return build_empty_list_array(item_type); + } + let def_levels = self + .item_reader + .get_def_levels() + .ok_or(ArrowError("item_reader def levels are None.".to_string()))?; + let rep_levels = self + .item_reader + .get_rep_levels() + .ok_or(ArrowError("item_reader rep levels are None.".to_string()))?; + + if !((def_levels.len() == rep_levels.len()) + && (rep_levels.len() == next_batch_array.len())) + { + return Err(ArrowError( + "Expected item_reader def_levels and rep_levels to be same length as batch".to_string(), + )); + } + + // Need to remove from the values array the nulls that represent null lists rather than null items + // null lists have def_level = 0 + let mut null_list_indices: Vec = Vec::new(); + for i in 0..def_levels.len() { + if def_levels[i] == 0 { + null_list_indices.push(i); + } + } + let batch_values = match null_list_indices.len() { + 0 => next_batch_array.clone(), + _ => remove_indices(next_batch_array.clone(), item_type, null_list_indices)?, + }; + + // null list has def_level = 0 + // empty list has def_level = 1 + // null item in a list has def_level = 2 + // non-null item has def_level = 3 + // first item in each list has rep_level = 0, subsequent items have rep_level = 1 + + let mut offsets = Vec::new(); + let mut cur_offset = 0; + for i in 0..rep_levels.len() { + if rep_levels[i] == (0 as 
i16) { + offsets.push(cur_offset) + } + if def_levels[i] > 0 { + cur_offset = cur_offset + 1; + } + } + offsets.push(cur_offset); + + let num_bytes = bit_util::ceil(offsets.len(), 8); + let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); + let null_slice = null_buf.data_mut(); + let mut list_index = 0; + for i in 0..rep_levels.len() { + if rep_levels[i] == (0 as i16) && def_levels[i] != (0 as i16) { + bit_util::set_bit(null_slice, list_index); + } + if rep_levels[i] == (0 as i16) { + list_index = list_index + 1; + } + } + let value_offsets = Buffer::from(&offsets.to_byte_slice()); + + // null list has def_level = 0 + let null_count = def_levels.iter().filter(|x| x == &&(0 as i16)).count(); + + let list_data = ArrayData::builder(self.get_data_type().clone()) + .len(offsets.len() - 1) + .add_buffer(value_offsets.clone()) + .add_child_data(batch_values.data()) + .null_bit_buffer(null_buf.freeze()) + .null_count(null_count) + .offset(next_batch_array.offset()) + .build(); + + let result_array = ListArray::from(list_data); + return Ok(Arc::new(result_array)); + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_level_buffer + .as_ref() + .map(|buf| unsafe { buf.typed_data() }) + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.rep_level_buffer + .as_ref() + .map(|buf| unsafe { buf.typed_data() }) + } +} + /// Implementation of struct array reader. pub struct StructArrayReader { children: Vec>, @@ -675,8 +1091,6 @@ impl<'a> TypeVisitor>, &'a ArrayReaderBuilderContext for ArrayReaderBuilder { /// Build array reader for primitive type. - /// Currently we don't have a list reader implementation, so repeated type is not - /// supported yet. fn visit_primitive( &mut self, cur_type: TypePtr, @@ -761,17 +1175,69 @@ impl<'a> TypeVisitor>, &'a ArrayReaderBuilderContext )) } - /// Build array reader for list type. - /// Currently this is not supported. + /// Build array reader for list type. 
Nested lists and lists of structs are not yet supported. fn visit_list_with_item( &mut self, - _list_type: Rc, - _item_type: &Type, - _context: &'a ArrayReaderBuilderContext, + list_type: Rc, + item_type: Rc, + context: &'a ArrayReaderBuilderContext, ) -> Result>> { - Err(ArrowError( - "Reading parquet list array into arrow is not supported yet!".to_string(), - )) + let list_child = &list_type + .get_fields() + .first() + .ok_or(ArrowError("List field must have a child.".to_string()))?; + let mut new_context = context.clone(); + + new_context.path.append(vec![list_type.name().to_string()]); + + match list_type.get_basic_info().repetition() { + Repetition::REPEATED => { + new_context.def_level += 1; + new_context.rep_level += 1; + } + Repetition::OPTIONAL => { + new_context.def_level += 1; + } + _ => (), + } + + match list_child.get_basic_info().repetition() { + Repetition::REPEATED => { + new_context.def_level += 1; + new_context.rep_level += 1; + } + Repetition::OPTIONAL => { + new_context.def_level += 1; + } + _ => (), + } + + let item_reader = self + .dispatch(item_type.clone(), &new_context) + .unwrap() + .unwrap(); + + let item_reader_type = item_reader.get_data_type().clone(); + + match item_reader_type { + ArrowType::List(_) + | ArrowType::FixedSizeList(_, _) + | ArrowType::Struct(_) + | ArrowType::Dictionary(_, _) => Err(ArrowError(format!( + "reading List({:?}) into arrow not supported yet", + item_type + ))), + _ => { + let arrow_type = ArrowType::List(Box::new(item_reader_type.clone())); + Ok(Some(Box::new(ListArrayReader::new( + item_reader, + arrow_type, + item_reader_type, + new_context.def_level, + new_context.rep_level, + )))) + } + } } } @@ -933,6 +1399,10 @@ impl<'a> ArrayReaderBuilder { #[cfg(test)] mod tests { use super::*; + use crate::arrow::array_reader::{ + build_array_reader, ArrayReader, ListArrayReader, PrimitiveArrayReader, + StructArrayReader, + }; use crate::arrow::converter::Utf8Converter; use crate::basic::{Encoding, Type as 
PhysicalType}; use crate::column::page::Page; @@ -945,7 +1415,9 @@ mod tests { DataPageBuilder, DataPageBuilderImpl, InMemoryPageIterator, }; use crate::util::test_common::{get_test_file, make_pages}; - use arrow::array::{Array, ArrayRef, PrimitiveArray, StringArray, StructArray}; + use arrow::array::{ + Array, ArrayRef, ListArray, PrimitiveArray, StringArray, StructArray, + }; use arrow::datatypes::{ DataType as ArrowType, Field, Int32Type as ArrowInt32, TimestampMicrosecondType as ArrowTimestampMicrosecondType, @@ -1524,4 +1996,56 @@ mod tests { assert_eq!(array_reader.get_data_type(), &arrow_type); } + + #[test] + fn test_list_array_reader() { + // [[1, null, 2], null, [3, 4]] + let array = Arc::new(PrimitiveArray::::from(vec![ + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + ])); + let item_array_reader = InMemoryArrayReader::new( + ArrowType::Int32, + array.clone(), + Some(vec![3, 2, 3, 0, 3, 3]), + Some(vec![0, 1, 1, 0, 0, 1]), + ); + + let mut list_array_reader = ListArrayReader::new( + Box::new(item_array_reader), + ArrowType::List(Box::new(ArrowType::Int32)), + ArrowType::Int32, + 1, + 1, + ); + + let next_batch = list_array_reader.next_batch(1024).unwrap(); + let list_array = next_batch.as_any().downcast_ref::().unwrap(); + + assert_eq!(3, list_array.len()); + + assert_eq!( + list_array + .value(0) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(1), None, Some(2)]) + ); + + assert!(list_array.is_null(1)); + + assert_eq!( + list_array + .value(2) + .as_any() + .downcast_ref::>() + .unwrap(), + &PrimitiveArray::::from(vec![Some(3), Some(4)]) + ); + } } diff --git a/rust/parquet/src/schema/visitor.rs b/rust/parquet/src/schema/visitor.rs index 6970f9ed47a..0ed818e9ce2 100644 --- a/rust/parquet/src/schema/visitor.rs +++ b/rust/parquet/src/schema/visitor.rs @@ -50,7 +50,7 @@ pub trait TypeVisitor { { self.visit_list_with_item( list_type.clone(), - list_item, + list_item.clone(), context, ) } else { @@ -70,13 +70,13 
@@ pub trait TypeVisitor { { self.visit_list_with_item( list_type.clone(), - fields.first().unwrap(), + fields.first().unwrap().clone(), context, ) } else { self.visit_list_with_item( list_type.clone(), - list_item, + list_item.clone(), context, ) } @@ -114,7 +114,7 @@ pub trait TypeVisitor { fn visit_list_with_item( &mut self, list_type: TypePtr, - item_type: &Type, + item_type: TypePtr, context: C, ) -> Result; } @@ -174,7 +174,7 @@ mod tests { fn visit_list_with_item( &mut self, list_type: TypePtr, - item_type: &Type, + item_type: TypePtr, _context: TestVisitorContext, ) -> Result { assert_eq!(