diff --git a/rust/parquet/src/arrow/schema.rs b/rust/parquet/src/arrow/schema.rs index 8a65964d958..4c390e13299 100644 --- a/rust/parquet/src/arrow/schema.rs +++ b/rust/parquet/src/arrow/schema.rs @@ -24,12 +24,13 @@ //! The interfaces for converting arrow schema to parquet schema is coming. use std::collections::{HashMap, HashSet}; +use std::rc::Rc; use crate::basic::{LogicalType, Repetition, Type as PhysicalType}; use crate::errors::{ParquetError::ArrowError, Result}; -use crate::schema::types::{ColumnDescriptor, SchemaDescriptor, Type}; - use crate::file::metadata::KeyValue; +use crate::schema::types::{ColumnDescriptor, SchemaDescriptor, Type, TypePtr}; + use arrow::datatypes::TimeUnit; use arrow::datatypes::{DataType, DateUnit, Field, Schema}; @@ -82,6 +83,19 @@ where .map(|fields| Schema::new_with_metadata(fields, metadata)) } +/// Convert arrow schema to parquet schema +pub fn arrow_to_parquet_schema(schema: &Schema) -> Result { + let fields: Result> = schema + .fields() + .iter() + .map(|field| arrow_to_parquet_type(field).map(|f| Rc::new(f))) + .collect(); + let group = Type::group_type_builder("arrow_schema") + .with_fields(&mut fields?) + .build()?; + Ok(SchemaDescriptor::new(Rc::new(group))) +} + fn parse_key_value_metadata( key_value_metadata: &Option>, ) -> Option> { @@ -118,6 +132,143 @@ pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result Result { + let name = field.name().as_str(); + let repetition = if field.is_nullable() { + Repetition::OPTIONAL + } else { + Repetition::REQUIRED + }; + // create type from field + match field.data_type() { + DataType::Boolean => Type::primitive_type_builder(name, PhysicalType::BOOLEAN) + .with_repetition(repetition) + .build(), + DataType::Int8 => Type::primitive_type_builder(name, PhysicalType::INT32) + .with_logical_type(LogicalType::INT_8) + .with_repetition(repetition) + .build(), + DataType::Int16 => Type::primitive_type_builder(name, PhysicalType::INT32) + .with_logical_type(LogicalType::INT_16) + .with_repetition(repetition) + .build(), + DataType::Int32 => Type::primitive_type_builder(name, PhysicalType::INT32) + .with_repetition(repetition) + .build(), + DataType::Int64 => Type::primitive_type_builder(name, PhysicalType::INT64) + .with_repetition(repetition) + .build(), + DataType::UInt8 => Type::primitive_type_builder(name, PhysicalType::INT32) + .with_logical_type(LogicalType::UINT_8) + .with_repetition(repetition) + .build(), + DataType::UInt16 => Type::primitive_type_builder(name, PhysicalType::INT32) + .with_logical_type(LogicalType::UINT_16) + .with_repetition(repetition) + .build(), + DataType::UInt32 => Type::primitive_type_builder(name, PhysicalType::INT32) + .with_logical_type(LogicalType::UINT_32) + .with_repetition(repetition) + .build(), + DataType::UInt64 => Type::primitive_type_builder(name, PhysicalType::INT64) + .with_logical_type(LogicalType::UINT_64) + .with_repetition(repetition) + .build(), + DataType::Float16 => Err(ArrowError("Float16 arrays not supported".to_string())), + DataType::Float32 => Type::primitive_type_builder(name, PhysicalType::FLOAT) + .with_repetition(repetition) + .build(), + DataType::Float64 => Type::primitive_type_builder(name, PhysicalType::DOUBLE) + .with_repetition(repetition) + .build(), + DataType::Timestamp(time_unit, _) => { + Type::primitive_type_builder(name, PhysicalType::INT64) + .with_logical_type(match time_unit { + TimeUnit::Second => LogicalType::TIMESTAMP_MILLIS, + TimeUnit::Millisecond => LogicalType::TIMESTAMP_MILLIS, + TimeUnit::Microsecond => LogicalType::TIMESTAMP_MICROS, + TimeUnit::Nanosecond => LogicalType::TIMESTAMP_MICROS, + }) + .with_repetition(repetition) + .build() + } + DataType::Date32(_) => Type::primitive_type_builder(name, PhysicalType::INT32) + .with_logical_type(LogicalType::DATE) + .with_repetition(repetition) + .build(), + DataType::Date64(_) => Type::primitive_type_builder(name, PhysicalType::INT32) + .with_logical_type(LogicalType::DATE) + .with_repetition(repetition) + .build(), + DataType::Time32(_) => Type::primitive_type_builder(name, PhysicalType::INT32) + .with_logical_type(LogicalType::TIME_MILLIS) + .with_repetition(repetition) + .build(), + DataType::Time64(_) => Type::primitive_type_builder(name, PhysicalType::INT64) + .with_logical_type(LogicalType::TIME_MICROS) + .with_repetition(repetition) + .build(), + DataType::Duration(_) => Err(ArrowError( + "Converting Duration to parquet not supported".to_string(), + )), + DataType::Interval(_) => { + Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_logical_type(LogicalType::INTERVAL) + .with_repetition(repetition) + .with_length(3) + .build() + } + DataType::Binary => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) + .with_repetition(repetition) + .build(), + DataType::FixedSizeBinary(length) => { + Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) + .with_repetition(repetition) + .with_length(*length) + .build() + } + DataType::Utf8 => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) + .with_logical_type(LogicalType::UTF8) + .with_repetition(repetition) + .build(), + DataType::List(dtype) | DataType::FixedSizeList(dtype, _) => { + Type::group_type_builder(name) + .with_fields(&mut vec![Rc::new( + Type::group_type_builder("list") + .with_fields(&mut vec![Rc::new({ + let list_field = Field::new( + "element", + *dtype.clone(), + field.is_nullable(), + ); + arrow_to_parquet_type(&list_field)? + })]) + .with_repetition(Repetition::REPEATED) + .build()?, + )]) + .with_logical_type(LogicalType::LIST) + .with_repetition(Repetition::REQUIRED) + .build() + } + DataType::Struct(fields) => { + // recursively convert children to types/nodes + let fields: Result> = fields + .into_iter() + .map(|f| arrow_to_parquet_type(f).map(Rc::new)) + .collect(); + Type::group_type_builder(name) + .with_fields(&mut fields?) + .with_repetition(repetition) + .build() + } + DataType::Dictionary(_, ref value) => { + // Dictionary encoding not handled at the schema level + let dict_field = Field::new(name, *value.clone(), field.is_nullable()); + arrow_to_parquet_type(&dict_field) + } + } +} /// This struct is used to group methods and data structures used to convert parquet /// schema together. struct ParquetTypeConverter<'a> { @@ -387,18 +538,14 @@ impl ParquetTypeConverter<'_> { #[cfg(test)] mod tests { - use std::rc::Rc; + use super::*; - use crate::schema::{parser::parse_message_type, types::SchemaDescriptor}; + use std::collections::HashMap; use arrow::datatypes::{DataType, DateUnit, Field, TimeUnit}; - use super::{ - parquet_to_arrow_field, parquet_to_arrow_schema, - parquet_to_arrow_schema_by_columns, - }; use crate::file::metadata::KeyValue; - use std::collections::HashMap; + use crate::schema::{parser::parse_message_type, types::SchemaDescriptor}; #[test] fn test_flat_primitives() { @@ -918,6 +1065,98 @@ mod tests { assert_eq!(arrow_fields, converted_arrow_fields); } + #[test] + fn test_field_to_column_desc() { + let message_type = " + message arrow_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + REQUIRED INT32 int16 (INT_16); + REQUIRED INT32 int32; + REQUIRED INT64 int64; + OPTIONAL DOUBLE double; + OPTIONAL FLOAT float; + OPTIONAL BINARY string (UTF8); + REQUIRED GROUP bools (LIST) { + REPEATED GROUP list { + OPTIONAL BOOLEAN element; + } + } + OPTIONAL INT32 date (DATE); + OPTIONAL INT32 time_milli (TIME_MILLIS); + OPTIONAL INT64 time_micro (TIME_MICROS); + OPTIONAL INT64 ts_milli (TIMESTAMP_MILLIS); + REQUIRED INT64 ts_micro (TIMESTAMP_MICROS); + REQUIRED GROUP struct { + REQUIRED BOOLEAN bools; + REQUIRED INT32 uint32 (UINT_32); + REQUIRED GROUP int32 (LIST) { + REPEATED GROUP list { + OPTIONAL INT32 element; + } + } + } + REQUIRED BINARY dictionary_strings (UTF8); + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + + let parquet_schema = SchemaDescriptor::new(Rc::new(parquet_group_type)); + + let arrow_fields = vec![ + Field::new("boolean", DataType::Boolean, false), + Field::new("int8", DataType::Int8, false), + Field::new("int16", DataType::Int16, false), + Field::new("int32", DataType::Int32, false), + Field::new("int64", DataType::Int64, false), + Field::new("double", DataType::Float64, true), + Field::new("float", DataType::Float32, true), + Field::new("string", DataType::Utf8, true), + Field::new("bools", DataType::List(Box::new(DataType::Boolean)), true), + Field::new("date", DataType::Date32(DateUnit::Day), true), + Field::new("time_milli", DataType::Time32(TimeUnit::Millisecond), true), + Field::new("time_micro", DataType::Time64(TimeUnit::Microsecond), true), + Field::new( + "ts_milli", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "ts_micro", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "struct", + DataType::Struct(vec![ + Field::new("bools", DataType::Boolean, false), + Field::new("uint32", DataType::UInt32, false), + Field::new("int32", DataType::List(Box::new(DataType::Int32)), true), + ]), + false, + ), + Field::new( + "dictionary_strings", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + ), + ]; + let arrow_schema = Schema::new(arrow_fields); + let converted_arrow_schema = arrow_to_parquet_schema(&arrow_schema).unwrap(); + + assert_eq!( + parquet_schema.columns().len(), + converted_arrow_schema.columns().len() + ); + parquet_schema + .columns() + .iter() + .zip(converted_arrow_schema.columns()) + .for_each(|(a, b)| { + assert_eq!(a, b); + }); + } + #[test] fn test_metadata() { let message_type = " diff --git a/rust/parquet/src/schema/types.rs b/rust/parquet/src/schema/types.rs index c8bfd9c94f8..e1227c283dc 100644 --- a/rust/parquet/src/schema/types.rs +++ b/rust/parquet/src/schema/types.rs @@ -587,6 +587,7 @@ impl AsRef<[String]> for ColumnPath { /// A descriptor for leaf-level primitive columns. /// This encapsulates information such as definition and repetition levels and is used to /// re-assemble nested data. +#[derive(Debug, PartialEq)] pub struct ColumnDescriptor { // The "leaf" primitive type of this column primitive_type: TypePtr,