Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 248 additions & 9 deletions rust/parquet/src/arrow/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
//! The interfaces for converting arrow schema to parquet schema is coming.

use std::collections::{HashMap, HashSet};
use std::rc::Rc;

use crate::basic::{LogicalType, Repetition, Type as PhysicalType};
use crate::errors::{ParquetError::ArrowError, Result};
use crate::schema::types::{ColumnDescriptor, SchemaDescriptor, Type};

use crate::file::metadata::KeyValue;
use crate::schema::types::{ColumnDescriptor, SchemaDescriptor, Type, TypePtr};

use arrow::datatypes::TimeUnit;
use arrow::datatypes::{DataType, DateUnit, Field, Schema};

Expand Down Expand Up @@ -82,6 +83,19 @@ where
.map(|fields| Schema::new_with_metadata(fields, metadata))
}

/// Convert arrow schema to parquet schema
pub fn arrow_to_parquet_schema(schema: &Schema) -> Result<SchemaDescriptor> {
let fields: Result<Vec<TypePtr>> = schema
.fields()
.iter()
.map(|field| arrow_to_parquet_type(field).map(|f| Rc::new(f)))
.collect();
let group = Type::group_type_builder("arrow_schema")
.with_fields(&mut fields?)
.build()?;
Ok(SchemaDescriptor::new(Rc::new(group)))
}

fn parse_key_value_metadata(
key_value_metadata: &Option<Vec<KeyValue>>,
) -> Option<HashMap<String, String>> {
Expand Down Expand Up @@ -118,6 +132,143 @@ pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result<Field
.map(|opt| opt.unwrap())
}

/// Convert an arrow field to a parquet `Type`
fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
let name = field.name().as_str();
let repetition = if field.is_nullable() {
Repetition::OPTIONAL
} else {
Repetition::REQUIRED
};
// create type from field
match field.data_type() {
DataType::Boolean => Type::primitive_type_builder(name, PhysicalType::BOOLEAN)
.with_repetition(repetition)
.build(),
DataType::Int8 => Type::primitive_type_builder(name, PhysicalType::INT32)
.with_logical_type(LogicalType::INT_8)
.with_repetition(repetition)
.build(),
DataType::Int16 => Type::primitive_type_builder(name, PhysicalType::INT32)
.with_logical_type(LogicalType::INT_16)
.with_repetition(repetition)
.build(),
DataType::Int32 => Type::primitive_type_builder(name, PhysicalType::INT32)
.with_repetition(repetition)
.build(),
DataType::Int64 => Type::primitive_type_builder(name, PhysicalType::INT64)
.with_repetition(repetition)
.build(),
DataType::UInt8 => Type::primitive_type_builder(name, PhysicalType::INT32)
.with_logical_type(LogicalType::UINT_8)
.with_repetition(repetition)
.build(),
DataType::UInt16 => Type::primitive_type_builder(name, PhysicalType::INT32)
.with_logical_type(LogicalType::UINT_16)
.with_repetition(repetition)
.build(),
DataType::UInt32 => Type::primitive_type_builder(name, PhysicalType::INT32)
.with_logical_type(LogicalType::UINT_32)
.with_repetition(repetition)
.build(),
DataType::UInt64 => Type::primitive_type_builder(name, PhysicalType::INT64)
.with_logical_type(LogicalType::UINT_64)
.with_repetition(repetition)
.build(),
DataType::Float16 => Err(ArrowError("Float16 arrays not supported".to_string())),
DataType::Float32 => Type::primitive_type_builder(name, PhysicalType::FLOAT)
.with_repetition(repetition)
.build(),
DataType::Float64 => Type::primitive_type_builder(name, PhysicalType::DOUBLE)
.with_repetition(repetition)
.build(),
DataType::Timestamp(time_unit, _) => {
Type::primitive_type_builder(name, PhysicalType::INT64)
.with_logical_type(match time_unit {
TimeUnit::Second => LogicalType::TIMESTAMP_MILLIS,
TimeUnit::Millisecond => LogicalType::TIMESTAMP_MILLIS,
TimeUnit::Microsecond => LogicalType::TIMESTAMP_MICROS,
TimeUnit::Nanosecond => LogicalType::TIMESTAMP_MICROS,
})
.with_repetition(repetition)
.build()
}
DataType::Date32(_) => Type::primitive_type_builder(name, PhysicalType::INT32)
.with_logical_type(LogicalType::DATE)
.with_repetition(repetition)
.build(),
DataType::Date64(_) => Type::primitive_type_builder(name, PhysicalType::INT32)
.with_logical_type(LogicalType::DATE)
.with_repetition(repetition)
.build(),
DataType::Time32(_) => Type::primitive_type_builder(name, PhysicalType::INT32)
.with_logical_type(LogicalType::TIME_MILLIS)
.with_repetition(repetition)
.build(),
DataType::Time64(_) => Type::primitive_type_builder(name, PhysicalType::INT64)
.with_logical_type(LogicalType::TIME_MICROS)
.with_repetition(repetition)
.build(),
DataType::Duration(_) => Err(ArrowError(
"Converting Duration to parquet not supported".to_string(),
)),
DataType::Interval(_) => {
Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY)
.with_logical_type(LogicalType::INTERVAL)
.with_repetition(repetition)
.with_length(3)
.build()
}
DataType::Binary => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY)
.with_repetition(repetition)
.build(),
DataType::FixedSizeBinary(length) => {
Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY)
.with_repetition(repetition)
.with_length(*length)
.build()
}
DataType::Utf8 => Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY)
.with_logical_type(LogicalType::UTF8)
.with_repetition(repetition)
.build(),
DataType::List(dtype) | DataType::FixedSizeList(dtype, _) => {
Type::group_type_builder(name)
.with_fields(&mut vec![Rc::new(
Type::group_type_builder("list")
.with_fields(&mut vec![Rc::new({
let list_field = Field::new(
"element",
*dtype.clone(),
field.is_nullable(),
);
arrow_to_parquet_type(&list_field)?
})])
.with_repetition(Repetition::REPEATED)
.build()?,
)])
.with_logical_type(LogicalType::LIST)
.with_repetition(Repetition::REQUIRED)
.build()
}
DataType::Struct(fields) => {
// recursively convert children to types/nodes
let fields: Result<Vec<TypePtr>> = fields
.into_iter()
.map(|f| arrow_to_parquet_type(f).map(Rc::new))
.collect();
Type::group_type_builder(name)
.with_fields(&mut fields?)
.with_repetition(repetition)
.build()
}
DataType::Dictionary(_, ref value) => {
// Dictionary encoding not handled at the schema level
let dict_field = Field::new(name, *value.clone(), field.is_nullable());
arrow_to_parquet_type(&dict_field)
}
}
}
/// This struct is used to group methods and data structures used to convert parquet
/// schema together.
struct ParquetTypeConverter<'a> {
Expand Down Expand Up @@ -387,18 +538,14 @@ impl ParquetTypeConverter<'_> {

#[cfg(test)]
mod tests {
use std::rc::Rc;
use super::*;

use crate::schema::{parser::parse_message_type, types::SchemaDescriptor};
use std::collections::HashMap;

use arrow::datatypes::{DataType, DateUnit, Field, TimeUnit};

use super::{
parquet_to_arrow_field, parquet_to_arrow_schema,
parquet_to_arrow_schema_by_columns,
};
use crate::file::metadata::KeyValue;
use std::collections::HashMap;
use crate::schema::{parser::parse_message_type, types::SchemaDescriptor};

#[test]
fn test_flat_primitives() {
Expand Down Expand Up @@ -918,6 +1065,98 @@ mod tests {
assert_eq!(arrow_fields, converted_arrow_fields);
}

#[test]
fn test_field_to_column_desc() {
let message_type = "
message arrow_schema {
REQUIRED BOOLEAN boolean;
REQUIRED INT32 int8 (INT_8);
REQUIRED INT32 int16 (INT_16);
REQUIRED INT32 int32;
REQUIRED INT64 int64;
OPTIONAL DOUBLE double;
OPTIONAL FLOAT float;
OPTIONAL BINARY string (UTF8);
REQUIRED GROUP bools (LIST) {
REPEATED GROUP list {
OPTIONAL BOOLEAN element;
}
}
OPTIONAL INT32 date (DATE);
OPTIONAL INT32 time_milli (TIME_MILLIS);
OPTIONAL INT64 time_micro (TIME_MICROS);
OPTIONAL INT64 ts_milli (TIMESTAMP_MILLIS);
REQUIRED INT64 ts_micro (TIMESTAMP_MICROS);
REQUIRED GROUP struct {
REQUIRED BOOLEAN bools;
REQUIRED INT32 uint32 (UINT_32);
REQUIRED GROUP int32 (LIST) {
REPEATED GROUP list {
OPTIONAL INT32 element;
}
}
}
REQUIRED BINARY dictionary_strings (UTF8);
}
";
let parquet_group_type = parse_message_type(message_type).unwrap();

let parquet_schema = SchemaDescriptor::new(Rc::new(parquet_group_type));

let arrow_fields = vec![
Field::new("boolean", DataType::Boolean, false),
Field::new("int8", DataType::Int8, false),
Field::new("int16", DataType::Int16, false),
Field::new("int32", DataType::Int32, false),
Field::new("int64", DataType::Int64, false),
Field::new("double", DataType::Float64, true),
Field::new("float", DataType::Float32, true),
Field::new("string", DataType::Utf8, true),
Field::new("bools", DataType::List(Box::new(DataType::Boolean)), true),
Field::new("date", DataType::Date32(DateUnit::Day), true),
Field::new("time_milli", DataType::Time32(TimeUnit::Millisecond), true),
Field::new("time_micro", DataType::Time64(TimeUnit::Microsecond), true),
Field::new(
"ts_milli",
DataType::Timestamp(TimeUnit::Millisecond, None),
true,
),
Field::new(
"ts_micro",
DataType::Timestamp(TimeUnit::Microsecond, None),
false,
),
Field::new(
"struct",
DataType::Struct(vec![
Field::new("bools", DataType::Boolean, false),
Field::new("uint32", DataType::UInt32, false),
Field::new("int32", DataType::List(Box::new(DataType::Int32)), true),
]),
false,
),
Field::new(
"dictionary_strings",
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
false,
),
];
let arrow_schema = Schema::new(arrow_fields);
let converted_arrow_schema = arrow_to_parquet_schema(&arrow_schema).unwrap();

assert_eq!(
parquet_schema.columns().len(),
converted_arrow_schema.columns().len()
);
parquet_schema
.columns()
.iter()
.zip(converted_arrow_schema.columns())
.for_each(|(a, b)| {
assert_eq!(a, b);
});
}

#[test]
fn test_metadata() {
let message_type = "
Expand Down
1 change: 1 addition & 0 deletions rust/parquet/src/schema/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ impl AsRef<[String]> for ColumnPath {
/// A descriptor for leaf-level primitive columns.
/// This encapsulates information such as definition and repetition levels and is used to
/// re-assemble nested data.
#[derive(Debug, PartialEq)]
pub struct ColumnDescriptor {
// The "leaf" primitive type of this column
primitive_type: TypePtr,
Expand Down