From 6daee0494716b3ba2ff7b01525abb982ab141c55 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 May 2022 07:54:32 -0600 Subject: [PATCH 01/11] copy proto and source from Ballista --- datafusion/proto/proto/datafusion.proto | 219 +++- datafusion/proto/src/from_proto.rs | 13 + datafusion/proto/src/lib.rs | 1 + datafusion/proto/src/logical_plan.rs | 1523 +++++++++++++++++++++++ datafusion/proto/src/to_proto.rs | 25 +- 5 files changed, 1776 insertions(+), 5 deletions(-) create mode 100644 datafusion/proto/src/logical_plan.rs diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index a4b1863615ca0..c2ddfa1b96467 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -21,7 +21,7 @@ syntax = "proto3"; package datafusion; option java_multiple_files = true; -option java_package = "org.datafusioncompute.protobuf"; +option java_package = "org.apache.arrow.datafusion.protobuf"; option java_outer_classname = "DatafusionProto"; message ColumnRelation { @@ -43,6 +43,223 @@ message DfSchema { map metadata = 2; } +// logical plan +// LogicalPlan is a nested type +message LogicalPlanNode { + oneof LogicalPlanType { + ListingTableScanNode listing_scan = 1; + ProjectionNode projection = 3; + SelectionNode selection = 4; + LimitNode limit = 5; + AggregateNode aggregate = 6; + JoinNode join = 7; + SortNode sort = 8; + RepartitionNode repartition = 9; + EmptyRelationNode empty_relation = 10; + CreateExternalTableNode create_external_table = 11; + ExplainNode explain = 12; + WindowNode window = 13; + AnalyzeNode analyze = 14; + CrossJoinNode cross_join = 15; + ValuesNode values = 16; + LogicalExtensionNode extension = 17; + CreateCatalogSchemaNode create_catalog_schema = 18; + UnionNode union = 19; + CreateCatalogNode create_catalog = 20; + SubqueryAliasNode subquery_alias = 21; + CreateViewNode create_view = 22; + OffsetNode offset = 23; + } +} + +message LogicalExtensionNode { + bytes node = 1; + repeated LogicalPlanNode inputs = 2; +} + +message ProjectionColumns { + repeated string columns = 1; +} + +message CsvFormat { + bool has_header = 1; + string delimiter = 2; +} + +message ParquetFormat { + bool enable_pruning = 1; +} + +message AvroFormat {} + +message ListingTableScanNode { + string table_name = 1; + string path = 2; + string file_extension = 3; + ProjectionColumns projection = 4; + datafusion.Schema schema = 5; + repeated datafusion.LogicalExprNode filters = 6; + repeated string table_partition_cols = 7; + bool collect_stat = 8; + uint32 target_partitions = 9; + oneof FileFormatType { + CsvFormat csv = 10; + ParquetFormat parquet = 11; + AvroFormat avro = 12; + } +} + +message ProjectionNode { + LogicalPlanNode input = 1; + repeated datafusion.LogicalExprNode expr = 2; + oneof optional_alias { + string alias = 3; + } +} + +message SelectionNode { + LogicalPlanNode input = 1; + datafusion.LogicalExprNode expr = 2; +} + +message SortNode { + LogicalPlanNode input = 1; + repeated datafusion.LogicalExprNode expr = 2; +} + +message RepartitionNode { + LogicalPlanNode input = 1; + oneof partition_method { + uint64 round_robin = 2; + HashRepartition hash = 3; + } +} + +message HashRepartition { + repeated datafusion.LogicalExprNode hash_expr = 1; + uint64 partition_count = 2; +} + +message EmptyRelationNode { + bool produce_one_row = 1; +} + +message CreateExternalTableNode { + string name = 1; + string location = 2; + FileType file_type = 3; + bool has_header = 4; + datafusion.DfSchema schema = 5; + repeated 
string table_partition_cols = 6; + bool if_not_exists = 7; + string delimiter = 8; +} + +message CreateCatalogSchemaNode { + string schema_name = 1; + bool if_not_exists = 2; + datafusion.DfSchema schema = 3; +} + +message CreateCatalogNode { + string catalog_name = 1; + bool if_not_exists = 2; + datafusion.DfSchema schema = 3; +} + +message CreateViewNode { + string name = 1; + LogicalPlanNode input = 2; + bool or_replace = 3; +} + +// a node containing data for defining values list. unlike in SQL where it's two dimensional, here +// the list is flattened, and with the field n_cols it can be parsed and partitioned into rows +message ValuesNode { + uint64 n_cols = 1; + repeated datafusion.LogicalExprNode values_list = 2; +} + +enum FileType { + NdJson = 0; + Parquet = 1; + CSV = 2; + Avro = 3; +} + +message AnalyzeNode { + LogicalPlanNode input = 1; + bool verbose = 2; +} + +message ExplainNode { + LogicalPlanNode input = 1; + bool verbose = 2; +} + +message AggregateNode { + LogicalPlanNode input = 1; + repeated datafusion.LogicalExprNode group_expr = 2; + repeated datafusion.LogicalExprNode aggr_expr = 3; +} + +message WindowNode { + LogicalPlanNode input = 1; + repeated datafusion.LogicalExprNode window_expr = 2; +} + +enum JoinType { + INNER = 0; + LEFT = 1; + RIGHT = 2; + FULL = 3; + SEMI = 4; + ANTI = 5; +} + +enum JoinConstraint { + ON = 0; + USING = 1; +} + +message JoinNode { + LogicalPlanNode left = 1; + LogicalPlanNode right = 2; + JoinType join_type = 3; + JoinConstraint join_constraint = 4; + repeated datafusion.Column left_join_column = 5; + repeated datafusion.Column right_join_column = 6; + bool null_equals_null = 7; +} + +message UnionNode { + repeated LogicalPlanNode inputs = 1; +} + +message CrossJoinNode { + LogicalPlanNode left = 1; + LogicalPlanNode right = 2; +} + +message LimitNode { + LogicalPlanNode input = 1; + uint32 limit = 2; +} + +message OffsetNode { + LogicalPlanNode input = 1; + uint32 offset = 2; +} + +message SelectionExecNode { + datafusion.LogicalExprNode expr = 1; +} + +message SubqueryAliasNode { + LogicalPlanNode input = 1; + string alias = 2; +} + // logical expressions message LogicalExprNode { oneof ExprType { diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index fd72db4edac46..6ccd4025b1aff 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -187,6 +187,19 @@ impl TryFrom<&protobuf::DfField> for DFField { } } +#[allow(clippy::from_over_into)] +impl Into for protobuf::FileType { + fn into(self) -> datafusion::logical_plan::FileType { + use datafusion::logical_plan::FileType; + match self { + protobuf::FileType::NdJson => FileType::NdJson, + protobuf::FileType::Parquet => FileType::Parquet, + protobuf::FileType::Csv => FileType::CSV, + protobuf::FileType::Avro => FileType::Avro, + } + } +} + impl From for WindowFrameUnits { fn from(units: protobuf::WindowFrameUnits) -> Self { match units { diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 3b042da423b01..13fce4cbe3211 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -23,6 +23,7 @@ pub mod protobuf { pub mod bytes; pub mod from_proto; +pub mod logical_plan; pub mod to_proto; #[cfg(test)] diff --git a/datafusion/proto/src/logical_plan.rs b/datafusion/proto/src/logical_plan.rs new file mode 100644 index 0000000000000..22924226829aa --- /dev/null +++ b/datafusion/proto/src/logical_plan.rs @@ -0,0 +1,1523 @@ +// Licensed to the Apache Software Foundation (ASF) under one 
+// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ + from_proto::{self, parse_expr}, + protobuf::{ + self, listing_table_scan_node::FileFormatType, + logical_plan_node::LogicalPlanType, JoinConstraint, LogicalExtensionNode, + LogicalPlanNode, Schema, + }, + to_proto, +}; +use datafusion::prelude::SessionContext; +use datafusion::{ + datasource::{ + file_format::{ + avro::AvroFormat, csv::CsvFormat, parquet::ParquetFormat, FileFormat, + }, + listing::{ListingOptions, ListingTable, ListingTableConfig}, + }, + logical_plan::{provider_as_source, source_as_provider}, +}; +use datafusion_common::{Column, DataFusionError}; +use datafusion_expr::{ + logical_plan::{ + Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, + CrossJoin, EmptyRelation, Extension, Filter, Join, Limit, Offset, Projection, + Repartition, Sort, SubqueryAlias, TableScan, Values, Window, + }, + Expr, LogicalPlan, LogicalPlanBuilder, +}; +use prost::bytes::BufMut; +use prost::Message; +use std::fmt::Debug; +use std::sync::Arc; + +fn byte_to_string(b: u8) -> Result { + let b = &[b]; + let b = std::str::from_utf8(b) + .map_err(|_| DataFusionError::General("Invalid CSV delimiter".to_owned()))?; + Ok(b.to_owned()) +} + +fn str_to_byte(s: &str) -> Result { + if s.len() != 1 { + return Err(DataFusionError::General("Invalid CSV delimiter".to_owned())); + } + Ok(s.as_bytes()[0]) +} + +pub(crate) fn proto_error>(message: S) -> DataFusionError { + DataFusionError::General(message.into()) +} + +pub trait AsLogicalPlan: Debug + Send + Sync + Clone { + fn try_decode(buf: &[u8]) -> Result + where + Self: Sized; + + fn try_encode(&self, buf: &mut B) -> Result<(), DataFusionError> + where + B: BufMut, + Self: Sized; + + fn try_into_logical_plan( + &self, + ctx: &SessionContext, + extension_codec: &dyn LogicalExtensionCodec, + ) -> Result; + + fn try_from_logical_plan( + plan: &LogicalPlan, + extension_codec: &dyn LogicalExtensionCodec, + ) -> Result + where + Self: Sized; +} + +pub trait LogicalExtensionCodec: Debug + Send + Sync { + fn try_decode( + &self, + buf: &[u8], + inputs: &[LogicalPlan], + ctx: &SessionContext, + ) -> Result; + + fn try_encode( + &self, + node: &Extension, + buf: &mut Vec, + ) -> Result<(), DataFusionError>; +} + +#[derive(Debug, Clone)] +pub struct DefaultLogicalExtensionCodec {} + +impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &SessionContext, + ) -> Result { + Err(DataFusionError::NotImplemented( + "LogicalExtensionCodec is not provided".to_string(), + )) + } + + fn try_encode( + &self, + _node: &Extension, + _buf: &mut Vec, + ) -> Result<(), DataFusionError> { + Err(DataFusionError::NotImplemented( + "LogicalExtensionCodec is not provided".to_string(), + )) + } +} + +#[macro_export] +macro_rules! 
into_logical_plan { + ($PB:expr, $CTX:expr, $CODEC:expr) => {{ + if let Some(field) = $PB.as_ref() { + field.as_ref().try_into_logical_plan($CTX, $CODEC) + } else { + Err(proto_error("Missing required field in protobuf")) + } + }}; +} + +#[macro_export] +macro_rules! convert_required { + ($PB:expr) => {{ + if let Some(field) = $PB.as_ref() { + Ok(field.try_into()?) + } else { + Err(proto_error("Missing required field in protobuf")) + } + }}; +} + +#[macro_export] +macro_rules! into_required { + ($PB:expr) => {{ + if let Some(field) = $PB.as_ref() { + Ok(field.into()) + } else { + Err(proto_error("Missing required field in protobuf")) + } + }}; +} + +#[macro_export] +macro_rules! convert_box_required { + ($PB:expr) => {{ + if let Some(field) = $PB.as_ref() { + field.as_ref().try_into() + } else { + Err(proto_error("Missing required field in protobuf")) + } + }}; +} + +impl AsLogicalPlan for LogicalPlanNode { + fn try_decode(buf: &[u8]) -> Result + where + Self: Sized, + { + LogicalPlanNode::decode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to decode logical plan: {:?}", e)) + }) + } + + fn try_encode(&self, buf: &mut B) -> Result<(), DataFusionError> + where + B: BufMut, + Self: Sized, + { + self.encode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to encode logical plan: {:?}", e)) + }) + } + + fn try_into_logical_plan( + &self, + ctx: &SessionContext, + extension_codec: &dyn LogicalExtensionCodec, + ) -> Result { + let plan = self.logical_plan_type.as_ref().ok_or_else(|| { + proto_error(format!( + "logical_plan::from_proto() Unsupported logical plan '{:?}'", + self + )) + })?; + match plan { + LogicalPlanType::Values(values) => { + let n_cols = values.n_cols as usize; + let values: Vec> = + if values.values_list.is_empty() { + Ok(Vec::new()) + } else if values.values_list.len() % n_cols != 0 { + Err(DataFusionError::General(format!( + "Invalid values list length, expect {} to be divisible by {}", + values.values_list.len(), + n_cols + ))) + } else { + values + .values_list + .chunks_exact(n_cols) + .map(|r| { + r.iter() + .map(|expr| parse_expr(expr, ctx)) + .collect::, from_proto::Error>>() + }) + .collect::, _>>() + .map_err(|e| e.into()) + }?; + LogicalPlanBuilder::values(values)? + .build() + .map_err(|e| e.into()) + } + LogicalPlanType::Projection(projection) => { + let input: LogicalPlan = + into_logical_plan!(projection.input, ctx, extension_codec)?; + let x: Vec = projection + .expr + .iter() + .map(|expr| parse_expr(expr, ctx)) + .collect::, _>>()?; + LogicalPlanBuilder::from(input) + .project_with_alias( + x, + projection.optional_alias.as_ref().map(|a| match a { + protobuf::projection_node::OptionalAlias::Alias(alias) => { + alias.clone() + } + }), + )? + .build() + .map_err(|e| e.into()) + } + LogicalPlanType::Selection(selection) => { + let input: LogicalPlan = + into_logical_plan!(selection.input, ctx, extension_codec)?; + let expr: Expr = selection + .expr + .as_ref() + .map(|expr| parse_expr(expr, ctx)) + .transpose()? + .ok_or_else(|| { + DataFusionError::General("expression required".to_string()) + })?; + // .try_into()?; + LogicalPlanBuilder::from(input) + .filter(expr)? + .build() + .map_err(|e| e.into()) + } + LogicalPlanType::Window(window) => { + let input: LogicalPlan = + into_logical_plan!(window.input, ctx, extension_codec)?; + let window_expr = window + .window_expr + .iter() + .map(|expr| parse_expr(expr, ctx)) + .collect::, _>>()?; + LogicalPlanBuilder::from(input) + .window(window_expr)? 
+ .build() + .map_err(|e| e.into()) + } + LogicalPlanType::Aggregate(aggregate) => { + let input: LogicalPlan = + into_logical_plan!(aggregate.input, ctx, extension_codec)?; + let group_expr = aggregate + .group_expr + .iter() + .map(|expr| parse_expr(expr, ctx)) + .collect::, _>>()?; + let aggr_expr = aggregate + .aggr_expr + .iter() + .map(|expr| parse_expr(expr, ctx)) + .collect::, _>>()?; + LogicalPlanBuilder::from(input) + .aggregate(group_expr, aggr_expr)? + .build() + .map_err(|e| e.into()) + } + LogicalPlanType::ListingScan(scan) => { + let schema: Schema = convert_required!(scan.schema)?; + + let mut projection = None; + if let Some(columns) = &scan.projection { + let column_indices = columns + .columns + .iter() + .map(|name| schema.index_of(name)) + .collect::, _>>()?; + projection = Some(column_indices); + } + + let filters = scan + .filters + .iter() + .map(|expr| parse_expr(expr, ctx)) + .collect::, _>>()?; + + let file_format: Arc = + match scan.file_format_type.as_ref().ok_or_else(|| { + proto_error(format!( + "logical_plan::from_proto() Unsupported file format '{:?}'", + self + )) + })? { + &FileFormatType::Parquet(protobuf::ParquetFormat { + enable_pruning, + }) => Arc::new( + ParquetFormat::default().with_enable_pruning(enable_pruning), + ), + FileFormatType::Csv(protobuf::CsvFormat { + has_header, + delimiter, + }) => Arc::new( + CsvFormat::default() + .with_has_header(*has_header) + .with_delimiter(str_to_byte(delimiter)?), + ), + FileFormatType::Avro(..) => Arc::new(AvroFormat::default()), + }; + + let options = ListingOptions { + file_extension: scan.file_extension.clone(), + format: file_format, + table_partition_cols: scan.table_partition_cols.clone(), + collect_stat: scan.collect_stat, + target_partitions: scan.target_partitions as usize, + }; + + let object_store = ctx + .runtime_env() + .object_store(scan.path.as_str()) + .map_err(|e| { + DataFusionError::NotImplemented(format!( + "No object store is registered for path {}: {:?}", + scan.path, e + )) + })? + .0; + + println!( + "Found object store {:?} for path {}", + object_store, + scan.path.as_str() + ); + + let config = ListingTableConfig::new(object_store, scan.path.as_str()) + .with_listing_options(options) + .with_schema(Arc::new(schema)); + + let provider = ListingTable::try_new(config)?; + + LogicalPlanBuilder::scan_with_filters( + &scan.table_name, + provider_as_source(Arc::new(provider)), + projection, + filters, + )? + .build() + .map_err(|e| e.into()) + } + LogicalPlanType::Sort(sort) => { + let input: LogicalPlan = + into_logical_plan!(sort.input, ctx, extension_codec)?; + let sort_expr: Vec = sort + .expr + .iter() + .map(|expr| parse_expr(expr, ctx)) + .collect::, _>>()?; + LogicalPlanBuilder::from(input) + .sort(sort_expr)? 
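+                    // sort expressions round-trip in their original order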
+ .build() + .map_err(|e| e.into()) + } + LogicalPlanType::Repartition(repartition) => { + use datafusion::logical_plan::Partitioning; + let input: LogicalPlan = + into_logical_plan!(repartition.input, ctx, extension_codec)?; + use protobuf::repartition_node::PartitionMethod; + let pb_partition_method = repartition.partition_method.clone().ok_or_else(|| { + DataFusionError::General(String::from( + "Protobuf deserialization error, RepartitionNode was missing required field 'partition_method'", + )) + })?; + + let partitioning_scheme = match pb_partition_method { + PartitionMethod::Hash(protobuf::HashRepartition { + hash_expr: pb_hash_expr, + partition_count, + }) => Partitioning::Hash( + pb_hash_expr + .iter() + .map(|expr| parse_expr(expr, ctx)) + .collect::, _>>()?, + partition_count as usize, + ), + PartitionMethod::RoundRobin(partition_count) => { + Partitioning::RoundRobinBatch(partition_count as usize) + } + }; + + LogicalPlanBuilder::from(input) + .repartition(partitioning_scheme)? + .build() + .map_err(|e| e.into()) + } + LogicalPlanType::EmptyRelation(empty_relation) => { + LogicalPlanBuilder::empty(empty_relation.produce_one_row) + .build() + .map_err(|e| e.into()) + } + LogicalPlanType::CreateExternalTable(create_extern_table) => { + let pb_schema = (create_extern_table.schema.clone()).ok_or_else(|| { + DataFusionError::General(String::from( + "Protobuf deserialization error, CreateExternalTableNode was missing required field schema.", + )) + })?; + + let pb_file_type: protobuf::FileType = + create_extern_table.file_type.try_into()?; + + Ok(LogicalPlan::CreateExternalTable(CreateExternalTable { + schema: pb_schema.try_into()?, + name: create_extern_table.name.clone(), + location: create_extern_table.location.clone(), + file_type: pb_file_type.into(), + has_header: create_extern_table.has_header, + delimiter: create_extern_table.delimiter.chars().next().ok_or_else(|| { + DataFusionError::General(String::from("Protobuf deserialization error, unable to parse CSV delimiter")) + })?, + table_partition_cols: create_extern_table + .table_partition_cols + .clone(), + if_not_exists: create_extern_table.if_not_exists, + })) + } + LogicalPlanType::CreateView(create_view) => { + let plan = create_view + .input.clone().ok_or_else(|| DataFusionError::General(String::from( + "Protobuf deserialization error, CreateViewNode has invalid LogicalPlan input.", + )))? 
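+                    // recursively decode the view's input plan, threading the
+                    // same extension codec through the nested call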
+ .try_into_logical_plan(ctx, extension_codec)?; + + Ok(LogicalPlan::CreateView(CreateView { + name: create_view.name.clone(), + input: Arc::new(plan), + or_replace: create_view.or_replace, + })) + } + LogicalPlanType::CreateCatalogSchema(create_catalog_schema) => { + let pb_schema = (create_catalog_schema.schema.clone()).ok_or_else(|| { + DataFusionError::General(String::from( + "Protobuf deserialization error, CreateCatalogSchemaNode was missing required field schema.", + )) + })?; + + Ok(LogicalPlan::CreateCatalogSchema(CreateCatalogSchema { + schema_name: create_catalog_schema.schema_name.clone(), + if_not_exists: create_catalog_schema.if_not_exists, + schema: pb_schema.try_into()?, + })) + } + LogicalPlanType::CreateCatalog(create_catalog) => { + let pb_schema = (create_catalog.schema.clone()).ok_or_else(|| { + DataFusionError::General(String::from( + "Protobuf deserialization error, CreateCatalogNode was missing required field schema.", + )) + })?; + + Ok(LogicalPlan::CreateCatalog(CreateCatalog { + catalog_name: create_catalog.catalog_name.clone(), + if_not_exists: create_catalog.if_not_exists, + schema: pb_schema.try_into()?, + })) + } + LogicalPlanType::Analyze(analyze) => { + let input: LogicalPlan = + into_logical_plan!(analyze.input, ctx, extension_codec)?; + LogicalPlanBuilder::from(input) + .explain(analyze.verbose, true)? + .build() + .map_err(|e| e.into()) + } + LogicalPlanType::Explain(explain) => { + let input: LogicalPlan = + into_logical_plan!(explain.input, ctx, extension_codec)?; + LogicalPlanBuilder::from(input) + .explain(explain.verbose, false)? + .build() + .map_err(|e| e.into()) + } + LogicalPlanType::SubqueryAlias(aliased_relation) => { + let input: LogicalPlan = + into_logical_plan!(aliased_relation.input, ctx, extension_codec)?; + LogicalPlanBuilder::from(input) + .alias(&aliased_relation.alias)? + .build() + .map_err(|e| e.into()) + } + LogicalPlanType::Limit(limit) => { + let input: LogicalPlan = + into_logical_plan!(limit.input, ctx, extension_codec)?; + LogicalPlanBuilder::from(input) + .limit(limit.limit as usize)? + .build() + .map_err(|e| e.into()) + } + LogicalPlanType::Offset(offset) => { + let input: LogicalPlan = + into_logical_plan!(offset.input, ctx, extension_codec)?; + LogicalPlanBuilder::from(input) + .offset(offset.offset as usize)? 
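+                    // the offset travels as u32 on the wire and is widened
+                    // back to usize here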
+                    .build()
+                    .map_err(|e| e.into())
+            }
+            LogicalPlanType::Join(join) => {
+                let left_keys: Vec<Column> =
+                    join.left_join_column.iter().map(|i| i.into()).collect();
+                let right_keys: Vec<Column> =
+                    join.right_join_column.iter().map(|i| i.into()).collect();
+                let join_type =
+                    protobuf::JoinType::from_i32(join.join_type).ok_or_else(|| {
+                        proto_error(format!(
+                            "Received a JoinNode message with unknown JoinType {}",
+                            join.join_type
+                        ))
+                    })?;
+                let join_constraint = protobuf::JoinConstraint::from_i32(
+                    join.join_constraint,
+                )
+                .ok_or_else(|| {
+                    proto_error(format!(
+                        "Received a JoinNode message with unknown JoinConstraint {}",
+                        join.join_constraint
+                    ))
+                })?;
+
+                let builder = LogicalPlanBuilder::from(into_logical_plan!(
+                    join.left,
+                    ctx,
+                    extension_codec
+                )?);
+                let builder = match join_constraint.into() {
+                    JoinConstraint::On => builder.join(
+                        &into_logical_plan!(join.right, ctx, extension_codec)?,
+                        join_type.into(),
+                        (left_keys, right_keys),
+                    )?,
+                    JoinConstraint::Using => builder.join_using(
+                        &into_logical_plan!(join.right, ctx, extension_codec)?,
+                        join_type.into(),
+                        left_keys,
+                    )?,
+                };
+
+                builder.build().map_err(|e| e.into())
+            }
+            LogicalPlanType::Union(union) => {
+                let mut input_plans: Vec<LogicalPlan> = union
+                    .inputs
+                    .iter()
+                    .map(|i| i.try_into_logical_plan(ctx, extension_codec))
+                    .collect::<Result<_, DataFusionError>>()?;
+
+                if input_plans.len() < 2 {
+                    return Err(DataFusionError::General(String::from(
+                        "Protobuf deserialization error, Union requires at least two inputs.",
+                    )));
+                }
+
+                let mut builder = LogicalPlanBuilder::from(input_plans.pop().unwrap());
+                for plan in input_plans {
+                    builder = builder.union(plan)?;
+                }
+                builder.build().map_err(|e| e.into())
+            }
+            LogicalPlanType::CrossJoin(crossjoin) => {
+                let left = into_logical_plan!(crossjoin.left, ctx, extension_codec)?;
+                let right = into_logical_plan!(crossjoin.right, ctx, extension_codec)?;
+
+                LogicalPlanBuilder::from(left)
+                    .cross_join(&right)?
+                    .build()
+                    .map_err(|e| e.into())
+            }
+            LogicalPlanType::Extension(LogicalExtensionNode { node, inputs }) => {
+                let input_plans: Vec<LogicalPlan> = inputs
+                    .iter()
+                    .map(|i| i.try_into_logical_plan(ctx, extension_codec))
+                    .collect::<Result<_, DataFusionError>>()?;
+
+                let extension_node =
+                    extension_codec.try_decode(node, &input_plans, ctx)?;
+                Ok(LogicalPlan::Extension(extension_node))
+            }
+        }
+    }
+
+    fn try_from_logical_plan(
+        plan: &LogicalPlan,
+        extension_codec: &dyn LogicalExtensionCodec,
+    ) -> Result<Self, DataFusionError>
+    where
+        Self: Sized,
+    {
+        match plan {
+            LogicalPlan::Values(Values { values, .. }) => {
+                let n_cols = if values.is_empty() {
+                    0
+                } else {
+                    values[0].len()
+                } as u64;
+                let values_list = values
+                    .iter()
+                    .flatten()
+                    .map(|v| v.try_into())
+                    .collect::<Result<Vec<_>, _>>()?;
+                Ok(protobuf::LogicalPlanNode {
+                    logical_plan_type: Some(LogicalPlanType::Values(
+                        protobuf::ValuesNode {
+                            n_cols,
+                            values_list,
+                        },
+                    )),
+                })
+            }
+            LogicalPlan::TableScan(TableScan {
+                table_name,
+                source,
+                filters,
+                projection,
+                ..
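+                // only ListingTable-backed scans can round-trip; the provider
+                // is downcast below and its file format captured in the protobuf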
+            }) => {
+                let source = source_as_provider(source)?;
+                let schema = source.schema();
+                let source = source.as_any();
+
+                let projection = match projection {
+                    None => None,
+                    Some(columns) => {
+                        let column_names = columns
+                            .iter()
+                            .map(|i| schema.field(*i).name().to_owned())
+                            .collect();
+                        Some(protobuf::ProjectionColumns {
+                            columns: column_names,
+                        })
+                    }
+                };
+                let schema: protobuf::Schema = schema.as_ref().into();
+
+                let filters: Vec<protobuf::LogicalExprNode> = filters
+                    .iter()
+                    .map(|filter| filter.try_into())
+                    .collect::<Result<Vec<_>, _>>()?;
+
+                if let Some(listing_table) = source.downcast_ref::<ListingTable>() {
+                    let any = listing_table.options().format.as_any();
+                    let file_format_type = if let Some(parquet) =
+                        any.downcast_ref::<ParquetFormat>()
+                    {
+                        FileFormatType::Parquet(protobuf::ParquetFormat {
+                            enable_pruning: parquet.enable_pruning(),
+                        })
+                    } else if let Some(csv) = any.downcast_ref::<CsvFormat>() {
+                        FileFormatType::Csv(protobuf::CsvFormat {
+                            delimiter: byte_to_string(csv.delimiter())?,
+                            has_header: csv.has_header(),
+                        })
+                    } else if any.is::<AvroFormat>() {
+                        FileFormatType::Avro(protobuf::AvroFormat {})
+                    } else {
+                        return Err(proto_error(format!(
+                            "Error converting file format, {:?} is not a supported DataFusion format.",
+                            listing_table.options().format
+                        )));
+                    };
+                    Ok(protobuf::LogicalPlanNode {
+                        logical_plan_type: Some(LogicalPlanType::ListingScan(
+                            protobuf::ListingTableScanNode {
+                                file_format_type: Some(file_format_type),
+                                table_name: table_name.to_owned(),
+                                collect_stat: listing_table.options().collect_stat,
+                                file_extension: listing_table
+                                    .options()
+                                    .file_extension
+                                    .clone(),
+                                table_partition_cols: listing_table
+                                    .options()
+                                    .table_partition_cols
+                                    .clone(),
+                                path: listing_table.table_path().to_owned(),
+                                schema: Some(schema),
+                                projection,
+                                filters,
+                                target_partitions: listing_table
+                                    .options()
+                                    .target_partitions
+                                    as u32,
+                            },
+                        )),
+                    })
+                } else {
+                    Err(DataFusionError::General(format!(
+                        "logical plan to_proto unsupported table provider {:?}",
+                        source
+                    )))
+                }
+            }
+            LogicalPlan::Projection(Projection {
+                expr, input, alias, ..
+            }) => Ok(protobuf::LogicalPlanNode {
+                logical_plan_type: Some(LogicalPlanType::Projection(Box::new(
+                    protobuf::ProjectionNode {
+                        input: Some(Box::new(
+                            protobuf::LogicalPlanNode::try_from_logical_plan(
+                                input.as_ref(),
+                                extension_codec,
+                            )?,
+                        )),
+                        expr: expr
+                            .iter()
+                            .map(|expr| expr.try_into())
+                            .collect::<Result<Vec<_>, to_proto::Error>>()?,
+                        optional_alias: alias
+                            .clone()
+                            .map(protobuf::projection_node::OptionalAlias::Alias),
+                    },
+                ))),
+            }),
+            LogicalPlan::Filter(Filter { predicate, input }) => {
+                let input: protobuf::LogicalPlanNode =
+                    protobuf::LogicalPlanNode::try_from_logical_plan(
+                        input.as_ref(),
+                        extension_codec,
+                    )?;
+                Ok(protobuf::LogicalPlanNode {
+                    logical_plan_type: Some(LogicalPlanType::Selection(Box::new(
+                        protobuf::SelectionNode {
+                            input: Some(Box::new(input)),
+                            expr: Some(predicate.try_into()?),
+                        },
+                    ))),
+                })
+            }
+            LogicalPlan::Window(Window {
+                input, window_expr, ..
+            }) => {
+                let input: protobuf::LogicalPlanNode =
+                    protobuf::LogicalPlanNode::try_from_logical_plan(
+                        input.as_ref(),
+                        extension_codec,
+                    )?;
+                Ok(protobuf::LogicalPlanNode {
+                    logical_plan_type: Some(LogicalPlanType::Window(Box::new(
+                        protobuf::WindowNode {
+                            input: Some(Box::new(input)),
+                            window_expr: window_expr
+                                .iter()
+                                .map(|expr| expr.try_into())
+                                .collect::<Result<Vec<_>, _>>()?,
+                        },
+                    ))),
+                })
+            }
+            LogicalPlan::Aggregate(Aggregate {
+                group_expr,
+                aggr_expr,
+                input,
+                ..
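+                // group and aggregate expression lists are serialized separately;
+                // the output schema is recomputed when the plan is decoded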
+ }) => { + let input: protobuf::LogicalPlanNode = + protobuf::LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::Aggregate(Box::new( + protobuf::AggregateNode { + input: Some(Box::new(input)), + group_expr: group_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + aggr_expr: aggr_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + }, + ))), + }) + } + LogicalPlan::Join(Join { + left, + right, + on, + join_type, + join_constraint, + null_equals_null, + .. + }) => { + let left: protobuf::LogicalPlanNode = + protobuf::LogicalPlanNode::try_from_logical_plan( + left.as_ref(), + extension_codec, + )?; + let right: protobuf::LogicalPlanNode = + protobuf::LogicalPlanNode::try_from_logical_plan( + right.as_ref(), + extension_codec, + )?; + let (left_join_column, right_join_column) = + on.iter().map(|(l, r)| (l.into(), r.into())).unzip(); + let join_type: protobuf::JoinType = join_type.to_owned().into(); + let join_constraint: protobuf::JoinConstraint = + join_constraint.to_owned().into(); + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::Join(Box::new( + protobuf::JoinNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + join_type: join_type.into(), + join_constraint: join_constraint.into(), + left_join_column, + right_join_column, + null_equals_null: *null_equals_null, + }, + ))), + }) + } + LogicalPlan::Subquery(_) => { + // note that the ballista and datafusion proto files need refactoring to allow + // LogicalExprNode to reference a LogicalPlanNode + // see https://github.com/apache/arrow-datafusion/issues/2338 + Err(DataFusionError::NotImplemented( + "Ballista does not support subqueries".to_string(), + )) + } + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. 
}) => {
+                let input: protobuf::LogicalPlanNode =
+                    protobuf::LogicalPlanNode::try_from_logical_plan(
+                        input.as_ref(),
+                        extension_codec,
+                    )?;
+                Ok(protobuf::LogicalPlanNode {
+                    logical_plan_type: Some(LogicalPlanType::SubqueryAlias(Box::new(
+                        protobuf::SubqueryAliasNode {
+                            input: Some(Box::new(input)),
+                            alias: alias.clone(),
+                        },
+                    ))),
+                })
+            }
+            LogicalPlan::Limit(Limit { input, n }) => {
+                let input: protobuf::LogicalPlanNode =
+                    protobuf::LogicalPlanNode::try_from_logical_plan(
+                        input.as_ref(),
+                        extension_codec,
+                    )?;
+                Ok(protobuf::LogicalPlanNode {
+                    logical_plan_type: Some(LogicalPlanType::Limit(Box::new(
+                        protobuf::LimitNode {
+                            input: Some(Box::new(input)),
+                            limit: *n as u32,
+                        },
+                    ))),
+                })
+            }
+            LogicalPlan::Offset(Offset { input, offset }) => {
+                let input: protobuf::LogicalPlanNode =
+                    protobuf::LogicalPlanNode::try_from_logical_plan(
+                        input.as_ref(),
+                        extension_codec,
+                    )?;
+                Ok(protobuf::LogicalPlanNode {
+                    logical_plan_type: Some(LogicalPlanType::Offset(Box::new(
+                        protobuf::OffsetNode {
+                            input: Some(Box::new(input)),
+                            offset: *offset as u32,
+                        },
+                    ))),
+                })
+            }
+            LogicalPlan::Sort(Sort { input, expr }) => {
+                let input: protobuf::LogicalPlanNode =
+                    protobuf::LogicalPlanNode::try_from_logical_plan(
+                        input.as_ref(),
+                        extension_codec,
+                    )?;
+                let selection_expr: Vec<protobuf::LogicalExprNode> = expr
+                    .iter()
+                    .map(|expr| expr.try_into())
+                    .collect::<Result<Vec<_>, to_proto::Error>>()?;
+                Ok(protobuf::LogicalPlanNode {
+                    logical_plan_type: Some(LogicalPlanType::Sort(Box::new(
+                        protobuf::SortNode {
+                            input: Some(Box::new(input)),
+                            expr: selection_expr,
+                        },
+                    ))),
+                })
+            }
+            LogicalPlan::Repartition(Repartition {
+                input,
+                partitioning_scheme,
+            }) => {
+                use datafusion::logical_plan::Partitioning;
+                let input: protobuf::LogicalPlanNode =
+                    protobuf::LogicalPlanNode::try_from_logical_plan(
+                        input.as_ref(),
+                        extension_codec,
+                    )?;
+
+                // partition counts are usize in DataFusion; serialize them as u64
+                // to stay independent of the platform's pointer width
+                use protobuf::repartition_node::PartitionMethod;
+
+                let pb_partition_method = match partitioning_scheme {
+                    Partitioning::Hash(exprs, partition_count) => {
+                        PartitionMethod::Hash(protobuf::HashRepartition {
+                            hash_expr: exprs
+                                .iter()
+                                .map(|expr| expr.try_into())
+                                .collect::<Result<Vec<_>, to_proto::Error>>()?,
+                            partition_count: *partition_count as u64,
+                        })
+                    }
+                    Partitioning::RoundRobinBatch(partition_count) => {
+                        PartitionMethod::RoundRobin(*partition_count as u64)
+                    }
+                };
+
+                Ok(protobuf::LogicalPlanNode {
+                    logical_plan_type: Some(LogicalPlanType::Repartition(Box::new(
+                        protobuf::RepartitionNode {
+                            input: Some(Box::new(input)),
+                            partition_method: Some(pb_partition_method),
+                        },
+                    ))),
+                })
+            }
+            LogicalPlan::EmptyRelation(EmptyRelation {
+                produce_one_row, ..
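+                // only the produce_one_row flag crosses the wire for EmptyRelation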
+ }) => Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::EmptyRelation( + protobuf::EmptyRelationNode { + produce_one_row: *produce_one_row, + }, + )), + }), + LogicalPlan::CreateExternalTable(CreateExternalTable { + name, + location, + file_type, + has_header, + delimiter, + schema: df_schema, + table_partition_cols, + if_not_exists, + }) => { + use datafusion::logical_plan::FileType; + + let pb_file_type: protobuf::FileType = match file_type { + FileType::NdJson => protobuf::FileType::NdJson, + FileType::Parquet => protobuf::FileType::Parquet, + FileType::CSV => protobuf::FileType::Csv, + FileType::Avro => protobuf::FileType::Avro, + }; + + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::CreateExternalTable( + protobuf::CreateExternalTableNode { + name: name.clone(), + location: location.clone(), + file_type: pb_file_type as i32, + has_header: *has_header, + schema: Some(df_schema.into()), + table_partition_cols: table_partition_cols.clone(), + if_not_exists: *if_not_exists, + delimiter: String::from(*delimiter), + }, + )), + }) + } + LogicalPlan::CreateView(CreateView { + name, + input, + or_replace, + }) => Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::CreateView(Box::new( + protobuf::CreateViewNode { + name: name.clone(), + input: Some(Box::new(LogicalPlanNode::try_from_logical_plan( + input, + extension_codec, + )?)), + or_replace: *or_replace, + }, + ))), + }), + LogicalPlan::CreateCatalogSchema(CreateCatalogSchema { + schema_name, + if_not_exists, + schema: df_schema, + }) => Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::CreateCatalogSchema( + protobuf::CreateCatalogSchemaNode { + schema_name: schema_name.clone(), + if_not_exists: *if_not_exists, + schema: Some(df_schema.into()), + }, + )), + }), + LogicalPlan::CreateCatalog(CreateCatalog { + catalog_name, + if_not_exists, + schema: df_schema, + }) => Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::CreateCatalog( + protobuf::CreateCatalogNode { + catalog_name: catalog_name.clone(), + if_not_exists: *if_not_exists, + schema: Some(df_schema.into()), + }, + )), + }), + LogicalPlan::Analyze(a) => { + let input = protobuf::LogicalPlanNode::try_from_logical_plan( + a.input.as_ref(), + extension_codec, + )?; + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::Analyze(Box::new( + protobuf::AnalyzeNode { + input: Some(Box::new(input)), + verbose: a.verbose, + }, + ))), + }) + } + LogicalPlan::Explain(a) => { + let input = protobuf::LogicalPlanNode::try_from_logical_plan( + a.plan.as_ref(), + extension_codec, + )?; + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::Explain(Box::new( + protobuf::ExplainNode { + input: Some(Box::new(input)), + verbose: a.verbose, + }, + ))), + }) + } + LogicalPlan::Union(union) => { + let inputs: Vec = union + .inputs + .iter() + .map(|i| { + protobuf::LogicalPlanNode::try_from_logical_plan( + i, + extension_codec, + ) + }) + .collect::>()?; + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::Union( + protobuf::UnionNode { inputs }, + )), + }) + } + LogicalPlan::CrossJoin(CrossJoin { left, right, .. 
}) => {
+                let left = protobuf::LogicalPlanNode::try_from_logical_plan(
+                    left.as_ref(),
+                    extension_codec,
+                )?;
+                let right = protobuf::LogicalPlanNode::try_from_logical_plan(
+                    right.as_ref(),
+                    extension_codec,
+                )?;
+                Ok(protobuf::LogicalPlanNode {
+                    logical_plan_type: Some(LogicalPlanType::CrossJoin(Box::new(
+                        protobuf::CrossJoinNode {
+                            left: Some(Box::new(left)),
+                            right: Some(Box::new(right)),
+                        },
+                    ))),
+                })
+            }
+            LogicalPlan::Extension(extension) => {
+                let mut buf: Vec<u8> = vec![];
+                extension_codec.try_encode(extension, &mut buf)?;
+
+                let inputs: Vec<LogicalPlanNode> = extension
+                    .node
+                    .inputs()
+                    .iter()
+                    .map(|i| {
+                        protobuf::LogicalPlanNode::try_from_logical_plan(
+                            i,
+                            extension_codec,
+                        )
+                    })
+                    .collect::<Result<_, DataFusionError>>()?;
+
+                Ok(protobuf::LogicalPlanNode {
+                    logical_plan_type: Some(LogicalPlanType::Extension(
+                        LogicalExtensionNode { node: buf, inputs },
+                    )),
+                })
+            }
+            LogicalPlan::CreateMemoryTable(_) => Err(proto_error(
+                "Error converting CreateMemoryTable. Not yet supported in Ballista",
+            )),
+            LogicalPlan::DropTable(_) => Err(proto_error(
+                "Error converting DropTable. Not yet supported in Ballista",
+            )),
+        }
+    }
+}
+
+#[cfg(test)]
+mod roundtrip_tests {
+
+    use super::super::{super::error::Result, protobuf};
+    use crate::serde::{AsLogicalPlan, BallistaCodec};
+    use async_trait::async_trait;
+    use core::panic;
+    use datafusion::common::DFSchemaRef;
+    use datafusion::logical_plan::source_as_provider;
+    use datafusion::{
+        arrow::datatypes::{DataType, Field, Schema},
+        datafusion_data_access::{
+            self,
+            object_store::{FileMetaStream, ListEntryStream, ObjectReader, ObjectStore},
+            SizedFile,
+        },
+        datasource::listing::ListingTable,
+        logical_plan::{
+            col, CreateExternalTable, Expr, FileType, LogicalPlan, LogicalPlanBuilder,
+            Repartition, ToDFSchema,
+        },
+        prelude::*,
+    };
+    use std::io;
+    use std::sync::Arc;
+
+    #[derive(Debug)]
+    struct TestObjectStore {}
+
+    #[async_trait]
+    impl ObjectStore for TestObjectStore {
+        async fn list_file(
+            &self,
+            _prefix: &str,
+        ) -> datafusion_data_access::Result<FileMetaStream> {
+            Err(io::Error::new(
+                io::ErrorKind::Unsupported,
+                "this is only a test object store".to_string(),
+            ))
+        }
+
+        async fn list_dir(
+            &self,
+            _prefix: &str,
+            _delimiter: Option<String>,
+        ) -> datafusion_data_access::Result<ListEntryStream> {
+            Err(io::Error::new(
+                io::ErrorKind::Unsupported,
+                "this is only a test object store".to_string(),
+            ))
+        }
+
+        fn file_reader(
+            &self,
+            _file: SizedFile,
+        ) -> datafusion_data_access::Result<Arc<dyn ObjectReader>> {
+            Err(io::Error::new(
+                io::ErrorKind::Unsupported,
+                "this is only a test object store".to_string(),
+            ))
+        }
+    }
+
+    // Given an instance of a LogicalPlan, converts it to protobuf and back,
+    // using debug formatting to test equality.
+    macro_rules!
roundtrip_test { + ($initial_struct:ident, $proto_type:ty, $struct_type:ty) => { + let proto: $proto_type = (&$initial_struct).try_into()?; + + let round_trip: $struct_type = (&proto).try_into()?; + + assert_eq!( + format!("{:?}", $initial_struct), + format!("{:?}", round_trip) + ); + }; + ($initial_struct:ident, $struct_type:ty) => { + roundtrip_test!($initial_struct, protobuf::LogicalPlanNode, $struct_type); + }; + ($initial_struct:ident) => { + let ctx = SessionContext::new(); + let codec: BallistaCodec< + protobuf::LogicalPlanNode, + protobuf::PhysicalPlanNode, + > = BallistaCodec::default(); + let proto: protobuf::LogicalPlanNode = + protobuf::LogicalPlanNode::try_from_logical_plan( + &$initial_struct, + codec.logical_extension_codec(), + ) + .expect("from logical plan"); + let round_trip: LogicalPlan = proto + .try_into_logical_plan(&ctx, codec.logical_extension_codec()) + .expect("to logical plan"); + + assert_eq!( + format!("{:?}", $initial_struct), + format!("{:?}", round_trip) + ); + }; + ($initial_struct:ident, $ctx:ident) => { + let codec: BallistaCodec< + protobuf::LogicalPlanNode, + protobuf::PhysicalPlanNode, + > = BallistaCodec::default(); + let proto: protobuf::LogicalPlanNode = + protobuf::LogicalPlanNode::try_from_logical_plan(&$initial_struct) + .expect("from logical plan"); + let round_trip: LogicalPlan = proto + .try_into_logical_plan(&$ctx, codec.logical_extension_codec()) + .expect("to logical plan"); + + assert_eq!( + format!("{:?}", $initial_struct), + format!("{:?}", round_trip) + ); + }; + } + + #[tokio::test] + async fn roundtrip_repartition() -> Result<()> { + use datafusion::logical_plan::Partitioning; + + let test_partition_counts = [usize::MIN, usize::MAX, 43256]; + + let test_expr: Vec = + vec![col("c1") + col("c2"), Expr::Literal((4.0).into())]; + + let plan = std::sync::Arc::new( + test_scan_csv("employee.csv", Some(vec![3, 4])) + .await? + .sort(vec![col("salary")])? 
+ .build()?, + ); + + for partition_count in test_partition_counts.iter() { + let rr_repartition = Partitioning::RoundRobinBatch(*partition_count); + + let roundtrip_plan = LogicalPlan::Repartition(Repartition { + input: plan.clone(), + partitioning_scheme: rr_repartition, + }); + + roundtrip_test!(roundtrip_plan); + + let h_repartition = Partitioning::Hash(test_expr.clone(), *partition_count); + + let roundtrip_plan = LogicalPlan::Repartition(Repartition { + input: plan.clone(), + partitioning_scheme: h_repartition, + }); + + roundtrip_test!(roundtrip_plan); + + let no_expr_hrepartition = Partitioning::Hash(Vec::new(), *partition_count); + + let roundtrip_plan = LogicalPlan::Repartition(Repartition { + input: plan.clone(), + partitioning_scheme: no_expr_hrepartition, + }); + + roundtrip_test!(roundtrip_plan); + } + + Ok(()) + } + + #[test] + fn roundtrip_create_external_table() -> Result<()> { + let schema = test_schema(); + + let df_schema_ref = schema.to_dfschema_ref()?; + + let filetypes: [FileType; 4] = [ + FileType::NdJson, + FileType::Parquet, + FileType::CSV, + FileType::Avro, + ]; + + for file in filetypes.iter() { + let create_table_node = + LogicalPlan::CreateExternalTable(CreateExternalTable { + schema: df_schema_ref.clone(), + name: String::from("TestName"), + location: String::from("employee.csv"), + file_type: *file, + has_header: true, + delimiter: ',', + table_partition_cols: vec![], + if_not_exists: false, + }); + + roundtrip_test!(create_table_node); + } + + Ok(()) + } + + #[tokio::test] + async fn roundtrip_analyze() -> Result<()> { + let verbose_plan = test_scan_csv("employee.csv", Some(vec![3, 4])) + .await? + .sort(vec![col("salary")])? + .explain(true, true)? + .build()?; + + let plan = test_scan_csv("employee.csv", Some(vec![3, 4])) + .await? + .sort(vec![col("salary")])? + .explain(false, true)? + .build()?; + + roundtrip_test!(plan); + + roundtrip_test!(verbose_plan); + + Ok(()) + } + + #[tokio::test] + async fn roundtrip_explain() -> Result<()> { + let verbose_plan = test_scan_csv("employee.csv", Some(vec![3, 4])) + .await? + .sort(vec![col("salary")])? + .explain(true, false)? + .build()?; + + let plan = test_scan_csv("employee.csv", Some(vec![3, 4])) + .await? + .sort(vec![col("salary")])? + .explain(false, false)? + .build()?; + + roundtrip_test!(plan); + + roundtrip_test!(verbose_plan); + + Ok(()) + } + + #[tokio::test] + async fn roundtrip_join() -> Result<()> { + let scan_plan = test_scan_csv("employee1", Some(vec![0, 3, 4])) + .await? + .build()?; + + let plan = test_scan_csv("employee2", Some(vec![0, 3, 4])) + .await? + .join(&scan_plan, JoinType::Inner, (vec!["id"], vec!["id"]))? + .build()?; + + roundtrip_test!(plan); + Ok(()) + } + + #[tokio::test] + async fn roundtrip_sort() -> Result<()> { + let plan = test_scan_csv("employee.csv", Some(vec![3, 4])) + .await? + .sort(vec![col("salary")])? + .build()?; + roundtrip_test!(plan); + + Ok(()) + } + + #[tokio::test] + async fn roundtrip_empty_relation() -> Result<()> { + let plan_false = LogicalPlanBuilder::empty(false).build()?; + + roundtrip_test!(plan_false); + + let plan_true = LogicalPlanBuilder::empty(true).build()?; + + roundtrip_test!(plan_true); + + Ok(()) + } + + #[tokio::test] + async fn roundtrip_logical_plan() -> Result<()> { + let plan = test_scan_csv("employee.csv", Some(vec![3, 4])) + .await? + .aggregate(vec![col("state")], vec![max(col("salary"))])? 
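+            // exercises group and aggregate expression serde end to end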
+ .build()?; + + roundtrip_test!(plan); + + Ok(()) + } + + #[ignore] // see https://github.com/apache/arrow-datafusion/issues/2546 + #[tokio::test] + async fn roundtrip_logical_plan_custom_ctx() -> Result<()> { + let ctx = SessionContext::new(); + let codec: BallistaCodec = + BallistaCodec::default(); + let custom_object_store = Arc::new(TestObjectStore {}); + ctx.runtime_env() + .register_object_store("test", custom_object_store.clone()); + + let (os, uri) = ctx.runtime_env().object_store("test://foo.csv")?; + assert_eq!("TestObjectStore", &format!("{:?}", os)); + assert_eq!("foo.csv", uri); + + let schema = test_schema(); + let plan = ctx + .read_csv( + "test://employee.csv", + CsvReadOptions::new().schema(&schema).has_header(true), + ) + .await? + .to_logical_plan()?; + + let proto: protobuf::LogicalPlanNode = + protobuf::LogicalPlanNode::try_from_logical_plan( + &plan, + codec.logical_extension_codec(), + ) + .expect("from logical plan"); + let round_trip: LogicalPlan = proto + .try_into_logical_plan(&ctx, codec.logical_extension_codec()) + .expect("to logical plan"); + + assert_eq!(format!("{:?}", plan), format!("{:?}", round_trip)); + + let round_trip_store = match round_trip { + LogicalPlan::TableScan(scan) => { + let source = source_as_provider(&scan.source)?; + match source.as_ref().as_any().downcast_ref::() { + Some(listing_table) => { + format!("{:?}", listing_table.object_store()) + } + _ => panic!("expected a ListingTable"), + } + } + _ => panic!("expected a TableScan"), + }; + + assert_eq!(round_trip_store, format!("{:?}", custom_object_store)); + + Ok(()) + } + + fn test_schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("first_name", DataType::Utf8, false), + Field::new("last_name", DataType::Utf8, false), + Field::new("state", DataType::Utf8, false), + Field::new("salary", DataType::Int32, false), + ]) + } + + async fn test_scan_csv( + table_name: &str, + projection: Option>, + ) -> Result { + let schema = test_schema(); + let ctx = SessionContext::new(); + let options = CsvReadOptions::new().schema(&schema); + let df = ctx.read_csv(table_name, options).await?; + let plan = match df.to_logical_plan()? 
{ + LogicalPlan::TableScan(ref scan) => { + let mut scan = scan.clone(); + scan.projection = projection; + let mut projected_schema = scan.projected_schema.as_ref().clone(); + projected_schema = projected_schema.replace_qualifier(table_name); + scan.projected_schema = DFSchemaRef::new(projected_schema); + LogicalPlan::TableScan(scan) + } + _ => unimplemented!(), + }; + Ok(LogicalPlanBuilder::from(plan)) + } +} diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 7aa4278b39a49..f02a69c5f8b66 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -30,11 +30,11 @@ use crate::protobuf::{ use arrow::datatypes::{ DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode, }; -use datafusion_common::{Column, DFField, DFSchemaRef, ScalarValue}; +use datafusion_common::{Column, DFField, DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_expr::{ - logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, - BuiltInWindowFunction, BuiltinScalarFunction, Expr, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunction, + logical_plan::{PlanType, StringifiedPlan}, + AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, Expr, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunction, }; #[derive(Debug)] @@ -415,6 +415,23 @@ impl From for protobuf::WindowFrame { } } +impl TryFrom for protobuf::FileType { + type Error = DataFusionError; + fn try_from(value: i32) -> Result { + use protobuf::FileType; + match value { + _x if _x == FileType::NdJson as i32 => Ok(FileType::NdJson), + _x if _x == FileType::Parquet as i32 => Ok(FileType::Parquet), + _x if _x == FileType::Csv as i32 => Ok(FileType::Csv), + _x if _x == FileType::Avro as i32 => Ok(FileType::Avro), + invalid => Err(DataFusionError::General(format!( + "Attempted to convert invalid i32 to protobuf::Filetype: {}", + invalid + ))), + } + } +} + impl TryFrom<&Expr> for protobuf::LogicalExprNode { type Error = Error; From 9c7b1ef9caac99cb481223a9d2beb0d5332c17d7 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 May 2022 08:11:46 -0600 Subject: [PATCH 02/11] it compiles --- datafusion/proto/Cargo.toml | 3 + datafusion/proto/src/lib.rs | 14 + datafusion/proto/src/logical_plan.rs | 474 ++++----------------------- datafusion/proto/src/to_proto.rs | 2 +- 4 files changed, 88 insertions(+), 405 deletions(-) diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 58a69124dbf4a..e19517f57436d 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -36,10 +36,13 @@ path = "src/lib.rs" [dependencies] arrow = { version = "14.0.0" } +async-trait = "0.1" datafusion = { path = "../core", version = "8.0.0" } datafusion-common = { path = "../common", version = "8.0.0" } +datafusion-data-access = { path = "../data-access", version = "8.0.0" } datafusion-expr = { path = "../expr", version = "8.0.0" } prost = "0.10" +tokio = "1.18" [build-dependencies] tonic-build = { version = "0.7" } diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 13fce4cbe3211..dcf91ee7e3b5d 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
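+// bring DataFusionError into scope; the From impls below let the proto serde
+// error types convert via `?`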
+use datafusion_common::DataFusionError; + // include the generated protobuf source as a submodule #[allow(clippy::all)] pub mod protobuf { @@ -26,6 +28,18 @@ pub mod from_proto; pub mod logical_plan; pub mod to_proto; +impl From for DataFusionError { + fn from(e: from_proto::Error) -> Self { + DataFusionError::Plan(e.to_string()) + } +} + +impl From for DataFusionError { + fn from(e: to_proto::Error) -> Self { + DataFusionError::Plan(e.to_string()) + } +} + #[cfg(test)] mod roundtrip_tests { use super::from_proto::parse_expr; diff --git a/datafusion/proto/src/logical_plan.rs b/datafusion/proto/src/logical_plan.rs index 22924226829aa..f2c4ec531b3d5 100644 --- a/datafusion/proto/src/logical_plan.rs +++ b/datafusion/proto/src/logical_plan.rs @@ -19,11 +19,11 @@ use crate::{ from_proto::{self, parse_expr}, protobuf::{ self, listing_table_scan_node::FileFormatType, - logical_plan_node::LogicalPlanType, JoinConstraint, LogicalExtensionNode, - LogicalPlanNode, Schema, + logical_plan_node::LogicalPlanType, LogicalExtensionNode, LogicalPlanNode, }, to_proto, }; +use arrow::datatypes::Schema; use datafusion::prelude::SessionContext; use datafusion::{ datasource::{ @@ -38,8 +38,9 @@ use datafusion_common::{Column, DataFusionError}; use datafusion_expr::{ logical_plan::{ Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, - CrossJoin, EmptyRelation, Extension, Filter, Join, Limit, Offset, Projection, - Repartition, Sort, SubqueryAlias, TableScan, Values, Window, + CrossJoin, EmptyRelation, Extension, Filter, Join, JoinConstraint, JoinType, + Limit, Offset, Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, + Window, }, Expr, LogicalPlan, LogicalPlanBuilder, }; @@ -51,19 +52,21 @@ use std::sync::Arc; fn byte_to_string(b: u8) -> Result { let b = &[b]; let b = std::str::from_utf8(b) - .map_err(|_| DataFusionError::General("Invalid CSV delimiter".to_owned()))?; + .map_err(|_| DataFusionError::Internal("Invalid CSV delimiter".to_owned()))?; Ok(b.to_owned()) } fn str_to_byte(s: &str) -> Result { if s.len() != 1 { - return Err(DataFusionError::General("Invalid CSV delimiter".to_owned())); + return Err(DataFusionError::Internal( + "Invalid CSV delimiter".to_owned(), + )); } Ok(s.as_bytes()[0]) } pub(crate) fn proto_error>(message: S) -> DataFusionError { - DataFusionError::General(message.into()) + DataFusionError::Internal(message.into()) } pub trait AsLogicalPlan: Debug + Send + Sync + Clone { @@ -175,6 +178,50 @@ macro_rules! 
convert_box_required { }}; } +impl From for JoinType { + fn from(t: protobuf::JoinType) -> Self { + match t { + protobuf::JoinType::Inner => JoinType::Inner, + protobuf::JoinType::Left => JoinType::Left, + protobuf::JoinType::Right => JoinType::Right, + protobuf::JoinType::Full => JoinType::Full, + protobuf::JoinType::Semi => JoinType::Semi, + protobuf::JoinType::Anti => JoinType::Anti, + } + } +} + +impl From for protobuf::JoinType { + fn from(t: JoinType) -> Self { + match t { + JoinType::Inner => protobuf::JoinType::Inner, + JoinType::Left => protobuf::JoinType::Left, + JoinType::Right => protobuf::JoinType::Right, + JoinType::Full => protobuf::JoinType::Full, + JoinType::Semi => protobuf::JoinType::Semi, + JoinType::Anti => protobuf::JoinType::Anti, + } + } +} + +impl From for JoinConstraint { + fn from(t: protobuf::JoinConstraint) -> Self { + match t { + protobuf::JoinConstraint::On => JoinConstraint::On, + protobuf::JoinConstraint::Using => JoinConstraint::Using, + } + } +} + +impl From for protobuf::JoinConstraint { + fn from(t: JoinConstraint) -> Self { + match t { + JoinConstraint::On => protobuf::JoinConstraint::On, + JoinConstraint::Using => protobuf::JoinConstraint::Using, + } + } +} + impl AsLogicalPlan for LogicalPlanNode { fn try_decode(buf: &[u8]) -> Result where @@ -213,7 +260,7 @@ impl AsLogicalPlan for LogicalPlanNode { if values.values_list.is_empty() { Ok(Vec::new()) } else if values.values_list.len() % n_cols != 0 { - Err(DataFusionError::General(format!( + Err(DataFusionError::Internal(format!( "Invalid values list length, expect {} to be divisible by {}", values.values_list.len(), n_cols @@ -263,7 +310,7 @@ impl AsLogicalPlan for LogicalPlanNode { .map(|expr| parse_expr(expr, ctx)) .transpose()? .ok_or_else(|| { - DataFusionError::General("expression required".to_string()) + DataFusionError::Internal("expression required".to_string()) })?; // .try_into()?; LogicalPlanBuilder::from(input) @@ -403,7 +450,7 @@ impl AsLogicalPlan for LogicalPlanNode { into_logical_plan!(repartition.input, ctx, extension_codec)?; use protobuf::repartition_node::PartitionMethod; let pb_partition_method = repartition.partition_method.clone().ok_or_else(|| { - DataFusionError::General(String::from( + DataFusionError::Internal(String::from( "Protobuf deserialization error, RepartitionNode was missing required field 'partition_method'", )) })?; @@ -436,7 +483,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::CreateExternalTable(create_extern_table) => { let pb_schema = (create_extern_table.schema.clone()).ok_or_else(|| { - DataFusionError::General(String::from( + DataFusionError::Internal(String::from( "Protobuf deserialization error, CreateExternalTableNode was missing required field schema.", )) })?; @@ -451,7 +498,7 @@ impl AsLogicalPlan for LogicalPlanNode { file_type: pb_file_type.into(), has_header: create_extern_table.has_header, delimiter: create_extern_table.delimiter.chars().next().ok_or_else(|| { - DataFusionError::General(String::from("Protobuf deserialization error, unable to parse CSV delimiter")) + DataFusionError::Internal(String::from("Protobuf deserialization error, unable to parse CSV delimiter")) })?, table_partition_cols: create_extern_table .table_partition_cols @@ -461,7 +508,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::CreateView(create_view) => { let plan = create_view - .input.clone().ok_or_else(|| DataFusionError::General(String::from( + .input.clone().ok_or_else(|| DataFusionError::Internal(String::from( "Protobuf deserialization 
error, CreateViewNode has invalid LogicalPlan input.",
                 )))?
                 .try_into_logical_plan(ctx, extension_codec)?;
@@ -474,7 +521,7 @@ impl AsLogicalPlan for LogicalPlanNode {
             }
             LogicalPlanType::CreateCatalogSchema(create_catalog_schema) => {
                 let pb_schema = (create_catalog_schema.schema.clone()).ok_or_else(|| {
-                    DataFusionError::General(String::from(
+                    DataFusionError::Internal(String::from(
                         "Protobuf deserialization error, CreateCatalogSchemaNode was missing required field schema.",
                     ))
                 })?;
@@ -487,7 +534,7 @@ impl AsLogicalPlan for LogicalPlanNode {
             }
             LogicalPlanType::CreateCatalog(create_catalog) => {
                 let pb_schema = (create_catalog.schema.clone()).ok_or_else(|| {
-                    DataFusionError::General(String::from(
+                    DataFusionError::Internal(String::from(
                         "Protobuf deserialization error, CreateCatalogNode was missing required field schema.",
                     ))
                 })?;
@@ -570,6 +617,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                         &into_logical_plan!(join.right, ctx, extension_codec)?,
                         join_type.into(),
                         (left_keys, right_keys),
+                        None, // filter
                     )?,
                     JoinConstraint::Using => builder.join_using(
                         &into_logical_plan!(join.right, ctx, extension_codec)?,
@@ -588,7 +636,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                     .collect::<Result<_, DataFusionError>>()?;
 
                 if input_plans.len() < 2 {
-                    return Err(DataFusionError::General(String::from(
+                    return Err(DataFusionError::Internal(String::from(
                         "Protobuf deserialization error, Union requires at least two inputs.",
                     )));
                 }
@@ -726,7 +774,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                     )),
                 })
             } else {
-                Err(DataFusionError::General(format!(
+                Err(DataFusionError::Internal(format!(
                     "logical plan to_proto unsupported table provider {:?}",
                     source
                 )))
@@ -854,14 +902,9 @@ impl AsLogicalPlan for LogicalPlanNode {
                     ))),
                 })
             }
-            LogicalPlan::Subquery(_) => {
-                // note that the ballista and datafusion proto files need refactoring to allow
-                // LogicalExprNode to reference a LogicalPlanNode
-                // see https://github.com/apache/arrow-datafusion/issues/2338
-                Err(DataFusionError::NotImplemented(
-                    "Ballista does not support subqueries".to_string(),
-                ))
-            }
+            LogicalPlan::Subquery(_) => Err(DataFusionError::NotImplemented(
+                "LogicalPlan serde is not yet implemented for subqueries".to_string(),
+            )),
             LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => {
                 let input: protobuf::LogicalPlanNode =
                     protobuf::LogicalPlanNode::try_from_logical_plan(
@@ -1136,388 +1179,11 @@ impl AsLogicalPlan for LogicalPlanNode {
                 })
             }
             LogicalPlan::CreateMemoryTable(_) => Err(proto_error(
-                "Error converting CreateMemoryTable. Not yet supported in Ballista",
+                "LogicalPlan serde is not yet implemented for CreateMemoryTable",
             )),
             LogicalPlan::DropTable(_) => Err(proto_error(
-                "Error converting DropTable. 
Not yet supported in Ballista", + "LogicalPlan serde is not yet implemented for DropTable", )), } } } - -#[cfg(test)] -mod roundtrip_tests { - - use super::super::{super::error::Result, protobuf}; - use crate::serde::{AsLogicalPlan, BallistaCodec}; - use async_trait::async_trait; - use core::panic; - use datafusion::common::DFSchemaRef; - use datafusion::logical_plan::source_as_provider; - use datafusion::{ - arrow::datatypes::{DataType, Field, Schema}, - datafusion_data_access::{ - self, - object_store::{FileMetaStream, ListEntryStream, ObjectReader, ObjectStore}, - SizedFile, - }, - datasource::listing::ListingTable, - logical_plan::{ - col, CreateExternalTable, Expr, FileType, LogicalPlan, LogicalPlanBuilder, - Repartition, ToDFSchema, - }, - prelude::*, - }; - use std::io; - use std::sync::Arc; - - #[derive(Debug)] - struct TestObjectStore {} - - #[async_trait] - impl ObjectStore for TestObjectStore { - async fn list_file( - &self, - _prefix: &str, - ) -> datafusion_data_access::Result { - Err(io::Error::new( - io::ErrorKind::Unsupported, - "this is only a test object store".to_string(), - )) - } - - async fn list_dir( - &self, - _prefix: &str, - _delimiter: Option, - ) -> datafusion_data_access::Result { - Err(io::Error::new( - io::ErrorKind::Unsupported, - "this is only a test object store".to_string(), - )) - } - - fn file_reader( - &self, - _file: SizedFile, - ) -> datafusion_data_access::Result> { - Err(io::Error::new( - io::ErrorKind::Unsupported, - "this is only a test object store".to_string(), - )) - } - } - - // Given a identity of a LogicalPlan converts it to protobuf and back, using debug formatting to test equality. - macro_rules! roundtrip_test { - ($initial_struct:ident, $proto_type:ty, $struct_type:ty) => { - let proto: $proto_type = (&$initial_struct).try_into()?; - - let round_trip: $struct_type = (&proto).try_into()?; - - assert_eq!( - format!("{:?}", $initial_struct), - format!("{:?}", round_trip) - ); - }; - ($initial_struct:ident, $struct_type:ty) => { - roundtrip_test!($initial_struct, protobuf::LogicalPlanNode, $struct_type); - }; - ($initial_struct:ident) => { - let ctx = SessionContext::new(); - let codec: BallistaCodec< - protobuf::LogicalPlanNode, - protobuf::PhysicalPlanNode, - > = BallistaCodec::default(); - let proto: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - &$initial_struct, - codec.logical_extension_codec(), - ) - .expect("from logical plan"); - let round_trip: LogicalPlan = proto - .try_into_logical_plan(&ctx, codec.logical_extension_codec()) - .expect("to logical plan"); - - assert_eq!( - format!("{:?}", $initial_struct), - format!("{:?}", round_trip) - ); - }; - ($initial_struct:ident, $ctx:ident) => { - let codec: BallistaCodec< - protobuf::LogicalPlanNode, - protobuf::PhysicalPlanNode, - > = BallistaCodec::default(); - let proto: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan(&$initial_struct) - .expect("from logical plan"); - let round_trip: LogicalPlan = proto - .try_into_logical_plan(&$ctx, codec.logical_extension_codec()) - .expect("to logical plan"); - - assert_eq!( - format!("{:?}", $initial_struct), - format!("{:?}", round_trip) - ); - }; - } - - #[tokio::test] - async fn roundtrip_repartition() -> Result<()> { - use datafusion::logical_plan::Partitioning; - - let test_partition_counts = [usize::MIN, usize::MAX, 43256]; - - let test_expr: Vec = - vec![col("c1") + col("c2"), Expr::Literal((4.0).into())]; - - let plan = std::sync::Arc::new( - test_scan_csv("employee.csv", 
Some(vec![3, 4])) - .await? - .sort(vec![col("salary")])? - .build()?, - ); - - for partition_count in test_partition_counts.iter() { - let rr_repartition = Partitioning::RoundRobinBatch(*partition_count); - - let roundtrip_plan = LogicalPlan::Repartition(Repartition { - input: plan.clone(), - partitioning_scheme: rr_repartition, - }); - - roundtrip_test!(roundtrip_plan); - - let h_repartition = Partitioning::Hash(test_expr.clone(), *partition_count); - - let roundtrip_plan = LogicalPlan::Repartition(Repartition { - input: plan.clone(), - partitioning_scheme: h_repartition, - }); - - roundtrip_test!(roundtrip_plan); - - let no_expr_hrepartition = Partitioning::Hash(Vec::new(), *partition_count); - - let roundtrip_plan = LogicalPlan::Repartition(Repartition { - input: plan.clone(), - partitioning_scheme: no_expr_hrepartition, - }); - - roundtrip_test!(roundtrip_plan); - } - - Ok(()) - } - - #[test] - fn roundtrip_create_external_table() -> Result<()> { - let schema = test_schema(); - - let df_schema_ref = schema.to_dfschema_ref()?; - - let filetypes: [FileType; 4] = [ - FileType::NdJson, - FileType::Parquet, - FileType::CSV, - FileType::Avro, - ]; - - for file in filetypes.iter() { - let create_table_node = - LogicalPlan::CreateExternalTable(CreateExternalTable { - schema: df_schema_ref.clone(), - name: String::from("TestName"), - location: String::from("employee.csv"), - file_type: *file, - has_header: true, - delimiter: ',', - table_partition_cols: vec![], - if_not_exists: false, - }); - - roundtrip_test!(create_table_node); - } - - Ok(()) - } - - #[tokio::test] - async fn roundtrip_analyze() -> Result<()> { - let verbose_plan = test_scan_csv("employee.csv", Some(vec![3, 4])) - .await? - .sort(vec![col("salary")])? - .explain(true, true)? - .build()?; - - let plan = test_scan_csv("employee.csv", Some(vec![3, 4])) - .await? - .sort(vec![col("salary")])? - .explain(false, true)? - .build()?; - - roundtrip_test!(plan); - - roundtrip_test!(verbose_plan); - - Ok(()) - } - - #[tokio::test] - async fn roundtrip_explain() -> Result<()> { - let verbose_plan = test_scan_csv("employee.csv", Some(vec![3, 4])) - .await? - .sort(vec![col("salary")])? - .explain(true, false)? - .build()?; - - let plan = test_scan_csv("employee.csv", Some(vec![3, 4])) - .await? - .sort(vec![col("salary")])? - .explain(false, false)? - .build()?; - - roundtrip_test!(plan); - - roundtrip_test!(verbose_plan); - - Ok(()) - } - - #[tokio::test] - async fn roundtrip_join() -> Result<()> { - let scan_plan = test_scan_csv("employee1", Some(vec![0, 3, 4])) - .await? - .build()?; - - let plan = test_scan_csv("employee2", Some(vec![0, 3, 4])) - .await? - .join(&scan_plan, JoinType::Inner, (vec!["id"], vec!["id"]))? - .build()?; - - roundtrip_test!(plan); - Ok(()) - } - - #[tokio::test] - async fn roundtrip_sort() -> Result<()> { - let plan = test_scan_csv("employee.csv", Some(vec![3, 4])) - .await? - .sort(vec![col("salary")])? - .build()?; - roundtrip_test!(plan); - - Ok(()) - } - - #[tokio::test] - async fn roundtrip_empty_relation() -> Result<()> { - let plan_false = LogicalPlanBuilder::empty(false).build()?; - - roundtrip_test!(plan_false); - - let plan_true = LogicalPlanBuilder::empty(true).build()?; - - roundtrip_test!(plan_true); - - Ok(()) - } - - #[tokio::test] - async fn roundtrip_logical_plan() -> Result<()> { - let plan = test_scan_csv("employee.csv", Some(vec![3, 4])) - .await? - .aggregate(vec![col("state")], vec![max(col("salary"))])? 
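// A minimal sketch (not a literal expansion) of what each `roundtrip_test!`
// invocation in these tests performs, per the macro defined earlier in this
// module; `Debug` output is used as the equality oracle for plans:
//
//   let proto = protobuf::LogicalPlanNode::try_from_logical_plan(
//       &plan, codec.logical_extension_codec()).expect("from logical plan");
//   let round_trip: LogicalPlan = proto
//       .try_into_logical_plan(&ctx, codec.logical_extension_codec())
//       .expect("to logical plan");
//   assert_eq!(format!("{:?}", plan), format!("{:?}", round_trip));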
- .build()?; - - roundtrip_test!(plan); - - Ok(()) - } - - #[ignore] // see https://github.com/apache/arrow-datafusion/issues/2546 - #[tokio::test] - async fn roundtrip_logical_plan_custom_ctx() -> Result<()> { - let ctx = SessionContext::new(); - let codec: BallistaCodec = - BallistaCodec::default(); - let custom_object_store = Arc::new(TestObjectStore {}); - ctx.runtime_env() - .register_object_store("test", custom_object_store.clone()); - - let (os, uri) = ctx.runtime_env().object_store("test://foo.csv")?; - assert_eq!("TestObjectStore", &format!("{:?}", os)); - assert_eq!("foo.csv", uri); - - let schema = test_schema(); - let plan = ctx - .read_csv( - "test://employee.csv", - CsvReadOptions::new().schema(&schema).has_header(true), - ) - .await? - .to_logical_plan()?; - - let proto: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - &plan, - codec.logical_extension_codec(), - ) - .expect("from logical plan"); - let round_trip: LogicalPlan = proto - .try_into_logical_plan(&ctx, codec.logical_extension_codec()) - .expect("to logical plan"); - - assert_eq!(format!("{:?}", plan), format!("{:?}", round_trip)); - - let round_trip_store = match round_trip { - LogicalPlan::TableScan(scan) => { - let source = source_as_provider(&scan.source)?; - match source.as_ref().as_any().downcast_ref::() { - Some(listing_table) => { - format!("{:?}", listing_table.object_store()) - } - _ => panic!("expected a ListingTable"), - } - } - _ => panic!("expected a TableScan"), - }; - - assert_eq!(round_trip_store, format!("{:?}", custom_object_store)); - - Ok(()) - } - - fn test_schema() -> Schema { - Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("first_name", DataType::Utf8, false), - Field::new("last_name", DataType::Utf8, false), - Field::new("state", DataType::Utf8, false), - Field::new("salary", DataType::Int32, false), - ]) - } - - async fn test_scan_csv( - table_name: &str, - projection: Option>, - ) -> Result { - let schema = test_schema(); - let ctx = SessionContext::new(); - let options = CsvReadOptions::new().schema(&schema); - let df = ctx.read_csv(table_name, options).await?; - let plan = match df.to_logical_plan()? 
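// test_scan_csv rewrites the TableScan produced by `read_csv` below: it
// overrides the projection and re-qualifies the projected schema with the
// table name, so that round-tripped plans compare equal under Debug
// formatting.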
{ - LogicalPlan::TableScan(ref scan) => { - let mut scan = scan.clone(); - scan.projection = projection; - let mut projected_schema = scan.projected_schema.as_ref().clone(); - projected_schema = projected_schema.replace_qualifier(table_name); - scan.projected_schema = DFSchemaRef::new(projected_schema); - LogicalPlan::TableScan(scan) - } - _ => unimplemented!(), - }; - Ok(LogicalPlanBuilder::from(plan)) - } -} diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index f02a69c5f8b66..91b62a6754a82 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -424,7 +424,7 @@ impl TryFrom for protobuf::FileType { _x if _x == FileType::Parquet as i32 => Ok(FileType::Parquet), _x if _x == FileType::Csv as i32 => Ok(FileType::Csv), _x if _x == FileType::Avro as i32 => Ok(FileType::Avro), - invalid => Err(DataFusionError::General(format!( + invalid => Err(DataFusionError::Internal(format!( "Attempted to convert invalid i32 to protobuf::Filetype: {}", invalid ))), From d23347e1be1521fbb73f8e3e7d3dd4c799bc2d0e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 May 2022 09:49:31 -0600 Subject: [PATCH 03/11] remove unnecessary map_err, include README --- datafusion/proto/Cargo.toml | 3 + datafusion/proto/README.md | 17 ++++++ datafusion/proto/src/from_proto.rs | 13 ----- datafusion/proto/src/lib.rs | 5 ++ datafusion/proto/src/logical_plan.rs | 82 +++++++++++++--------------- datafusion/proto/src/to_proto.rs | 25 ++------- 6 files changed, 68 insertions(+), 77 deletions(-) diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index e19517f57436d..afb9f7c5ab18f 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -44,5 +44,8 @@ datafusion-expr = { path = "../expr", version = "8.0.0" } prost = "0.10" tokio = "1.18" +[dev-dependencies] +doc-comment = "0.3" + [build-dependencies] tonic-build = { version = "0.7" } diff --git a/datafusion/proto/README.md b/datafusion/proto/README.md index b928652e9ec0d..acd319b834480 100644 --- a/datafusion/proto/README.md +++ b/datafusion/proto/README.md @@ -23,4 +23,21 @@ This crate is a submodule of DataFusion that provides a protocol buffer format for representing query plans and expressions. +The following example demonstrates serializing and deserializing a logical expression. + +``` rust +use datafusion_expr::{col, lit, Expr}; +use datafusion_proto::bytes::Serializeable; + +// Create a new `Expr` a < 32 +let expr = col("a").lt(lit(5i32)); + +// Convert it to an opaque form +let bytes = expr.to_bytes().unwrap(); + +// Decode bytes from somewhere (over network, etc.) 
+let decoded_expr = Expr::from_bytes(&bytes).unwrap(); +assert_eq!(expr, decoded_expr); +``` + [df]: https://crates.io/crates/datafusion diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 6ccd4025b1aff..fd72db4edac46 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -187,19 +187,6 @@ impl TryFrom<&protobuf::DfField> for DFField { } } -#[allow(clippy::from_over_into)] -impl Into for protobuf::FileType { - fn into(self) -> datafusion::logical_plan::FileType { - use datafusion::logical_plan::FileType; - match self { - protobuf::FileType::NdJson => FileType::NdJson, - protobuf::FileType::Parquet => FileType::Parquet, - protobuf::FileType::Csv => FileType::CSV, - protobuf::FileType::Avro => FileType::Avro, - } - } -} - impl From for WindowFrameUnits { fn from(units: protobuf::WindowFrameUnits) -> Self { match units { diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index dcf91ee7e3b5d..fed01dd5e27ad 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! Serde code for logical plans and expressions. + use datafusion_common::DataFusionError; // include the generated protobuf source as a submodule @@ -28,6 +30,9 @@ pub mod from_proto; pub mod logical_plan; pub mod to_proto; +#[cfg(doctest)] +doc_comment::doctest!("../README.md", readme_example_test); + impl From for DataFusionError { fn from(e: from_proto::Error) -> Self { DataFusionError::Plan(e.to_string()) diff --git a/datafusion/proto/src/logical_plan.rs b/datafusion/proto/src/logical_plan.rs index f2c4ec531b3d5..4993cfdce5dd5 100644 --- a/datafusion/proto/src/logical_plan.rs +++ b/datafusion/proto/src/logical_plan.rs @@ -178,6 +178,36 @@ macro_rules! convert_box_required { }}; } +#[allow(clippy::from_over_into)] +impl Into for protobuf::FileType { + fn into(self) -> datafusion::logical_plan::FileType { + use datafusion::logical_plan::FileType; + match self { + protobuf::FileType::NdJson => FileType::NdJson, + protobuf::FileType::Parquet => FileType::Parquet, + protobuf::FileType::Csv => FileType::CSV, + protobuf::FileType::Avro => FileType::Avro, + } + } +} + +impl TryFrom for protobuf::FileType { + type Error = DataFusionError; + fn try_from(value: i32) -> Result { + use protobuf::FileType; + match value { + _x if _x == FileType::NdJson as i32 => Ok(FileType::NdJson), + _x if _x == FileType::Parquet as i32 => Ok(FileType::Parquet), + _x if _x == FileType::Csv as i32 => Ok(FileType::Csv), + _x if _x == FileType::Avro as i32 => Ok(FileType::Avro), + invalid => Err(DataFusionError::Internal(format!( + "Attempted to convert invalid i32 to protobuf::Filetype: {}", + invalid + ))), + } + } +} + impl From for JoinType { fn from(t: protobuf::JoinType) -> Self { match t { @@ -277,9 +307,7 @@ impl AsLogicalPlan for LogicalPlanNode { .collect::, _>>() .map_err(|e| e.into()) }?; - LogicalPlanBuilder::values(values)? - .build() - .map_err(|e| e.into()) + LogicalPlanBuilder::values(values)?.build() } LogicalPlanType::Projection(projection) => { let input: LogicalPlan = @@ -299,7 +327,6 @@ impl AsLogicalPlan for LogicalPlanNode { }), )? .build() - .map_err(|e| e.into()) } LogicalPlanType::Selection(selection) => { let input: LogicalPlan = @@ -313,10 +340,7 @@ impl AsLogicalPlan for LogicalPlanNode { DataFusionError::Internal("expression required".to_string()) })?; // .try_into()?; - LogicalPlanBuilder::from(input) - .filter(expr)? 
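// The `.map_err(|e| e.into())` calls removed throughout this file date from
// when this code lived in Ballista and had to convert `DataFusionError`
// into Ballista's error type; these functions now return
// `Result<_, DataFusionError>` directly, so the result of
// `LogicalPlanBuilder::build` needs no further conversion.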
- .build() - .map_err(|e| e.into()) + LogicalPlanBuilder::from(input).filter(expr)?.build() } LogicalPlanType::Window(window) => { let input: LogicalPlan = @@ -326,10 +350,7 @@ impl AsLogicalPlan for LogicalPlanNode { .iter() .map(|expr| parse_expr(expr, ctx)) .collect::, _>>()?; - LogicalPlanBuilder::from(input) - .window(window_expr)? - .build() - .map_err(|e| e.into()) + LogicalPlanBuilder::from(input).window(window_expr)?.build() } LogicalPlanType::Aggregate(aggregate) => { let input: LogicalPlan = @@ -347,7 +368,6 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanBuilder::from(input) .aggregate(group_expr, aggr_expr)? .build() - .map_err(|e| e.into()) } LogicalPlanType::ListingScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; @@ -399,16 +419,7 @@ impl AsLogicalPlan for LogicalPlanNode { target_partitions: scan.target_partitions as usize, }; - let object_store = ctx - .runtime_env() - .object_store(scan.path.as_str()) - .map_err(|e| { - DataFusionError::NotImplemented(format!( - "No object store is registered for path {}: {:?}", - scan.path, e - )) - })? - .0; + let object_store = ctx.runtime_env().object_store(scan.path.as_str())?.0; println!( "Found object store {:?} for path {}", @@ -429,7 +440,6 @@ impl AsLogicalPlan for LogicalPlanNode { filters, )? .build() - .map_err(|e| e.into()) } LogicalPlanType::Sort(sort) => { let input: LogicalPlan = @@ -439,10 +449,7 @@ impl AsLogicalPlan for LogicalPlanNode { .iter() .map(|expr| parse_expr(expr, ctx)) .collect::, _>>()?; - LogicalPlanBuilder::from(input) - .sort(sort_expr)? - .build() - .map_err(|e| e.into()) + LogicalPlanBuilder::from(input).sort(sort_expr)?.build() } LogicalPlanType::Repartition(repartition) => { use datafusion::logical_plan::Partitioning; @@ -474,12 +481,9 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanBuilder::from(input) .repartition(partitioning_scheme)? .build() - .map_err(|e| e.into()) } LogicalPlanType::EmptyRelation(empty_relation) => { - LogicalPlanBuilder::empty(empty_relation.produce_one_row) - .build() - .map_err(|e| e.into()) + LogicalPlanBuilder::empty(empty_relation.produce_one_row).build() } LogicalPlanType::CreateExternalTable(create_extern_table) => { let pb_schema = (create_extern_table.schema.clone()).ok_or_else(|| { @@ -551,7 +555,6 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanBuilder::from(input) .explain(analyze.verbose, true)? .build() - .map_err(|e| e.into()) } LogicalPlanType::Explain(explain) => { let input: LogicalPlan = @@ -559,7 +562,6 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanBuilder::from(input) .explain(explain.verbose, false)? .build() - .map_err(|e| e.into()) } LogicalPlanType::SubqueryAlias(aliased_relation) => { let input: LogicalPlan = @@ -567,7 +569,6 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanBuilder::from(input) .alias(&aliased_relation.alias)? .build() - .map_err(|e| e.into()) } LogicalPlanType::Limit(limit) => { let input: LogicalPlan = @@ -575,7 +576,6 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanBuilder::from(input) .limit(limit.limit as usize)? .build() - .map_err(|e| e.into()) } LogicalPlanType::Offset(offset) => { let input: LogicalPlan = @@ -583,7 +583,6 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanBuilder::from(input) .offset(offset.offset as usize)? 
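// Note that these branches all rebuild the plan through
// `LogicalPlanBuilder`, so output schemas are recomputed from the decoded
// inputs rather than taken on trust from the serialized bytes.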
.build() - .map_err(|e| e.into()) } LogicalPlanType::Join(join) => { let left_keys: Vec = @@ -626,7 +625,7 @@ impl AsLogicalPlan for LogicalPlanNode { )?, }; - builder.build().map_err(|e| e.into()) + builder.build() } LogicalPlanType::Union(union) => { let mut input_plans: Vec = union @@ -645,16 +644,13 @@ impl AsLogicalPlan for LogicalPlanNode { for plan in input_plans { builder = builder.union(plan)?; } - builder.build().map_err(|e| e.into()) + builder.build() } LogicalPlanType::CrossJoin(crossjoin) => { let left = into_logical_plan!(crossjoin.left, ctx, extension_codec)?; let right = into_logical_plan!(crossjoin.right, ctx, extension_codec)?; - LogicalPlanBuilder::from(left) - .cross_join(&right)? - .build() - .map_err(|e| e.into()) + LogicalPlanBuilder::from(left).cross_join(&right)?.build() } LogicalPlanType::Extension(LogicalExtensionNode { node, inputs }) => { let input_plans: Vec = inputs diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 91b62a6754a82..7aa4278b39a49 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -30,11 +30,11 @@ use crate::protobuf::{ use arrow::datatypes::{ DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode, }; -use datafusion_common::{Column, DFField, DFSchemaRef, DataFusionError, ScalarValue}; +use datafusion_common::{Column, DFField, DFSchemaRef, ScalarValue}; use datafusion_expr::{ - logical_plan::{PlanType, StringifiedPlan}, - AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, Expr, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunction, + logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, + BuiltInWindowFunction, BuiltinScalarFunction, Expr, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunction, }; #[derive(Debug)] @@ -415,23 +415,6 @@ impl From for protobuf::WindowFrame { } } -impl TryFrom for protobuf::FileType { - type Error = DataFusionError; - fn try_from(value: i32) -> Result { - use protobuf::FileType; - match value { - _x if _x == FileType::NdJson as i32 => Ok(FileType::NdJson), - _x if _x == FileType::Parquet as i32 => Ok(FileType::Parquet), - _x if _x == FileType::Csv as i32 => Ok(FileType::Csv), - _x if _x == FileType::Avro as i32 => Ok(FileType::Avro), - invalid => Err(DataFusionError::Internal(format!( - "Attempted to convert invalid i32 to protobuf::Filetype: {}", - invalid - ))), - } - } -} - impl TryFrom<&Expr> for protobuf::LogicalExprNode { type Error = Error; From b4fc9c70d7ea33fb204952a8dc8f31ab8f87b367 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 May 2022 10:27:47 -0600 Subject: [PATCH 04/11] implement round-trip test for logical plan serde --- datafusion/proto/README.md | 10 ++++ datafusion/proto/src/bytes/mod.rs | 79 ++++++++++++++++++++++++++++-- datafusion/proto/src/lib.rs | 18 ++++++- datafusion/proto/testdata/test.csv | 3 ++ 4 files changed, 104 insertions(+), 6 deletions(-) create mode 100644 datafusion/proto/testdata/test.csv diff --git a/datafusion/proto/README.md b/datafusion/proto/README.md index acd319b834480..9d9cc79e68c94 100644 --- a/datafusion/proto/README.md +++ b/datafusion/proto/README.md @@ -40,4 +40,14 @@ let decoded_expr = Expr::from_bytes(&bytes).unwrap(); assert_eq!(expr, decoded_expr); ``` +The following example demonstrates serializing and deserializing a logical plan. 
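+Note that deserializing a plan requires a `SessionContext`: the decoder
+resolves object stores and any user-defined functions referenced by the
+plan against that context, so they must be registered on the receiving
+side.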
+ +``` rust +let ctx = SessionContext::new(); +ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()).await.unwrap(); +let plan = ctx.table("t1").unwrap().to_logical_plan().unwrap(); +let bytes = logical_plan_to_bytes(&plan).unwrap(); +let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).unwrap(); +``` + [df]: https://crates.io/crates/datafusion diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 1781756df61dc..37374b3eff5a4 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -16,14 +16,22 @@ // under the License. //! Serialization / Deserialization to Bytes -use crate::{from_proto::parse_expr, protobuf}; +use crate::{ + from_proto::parse_expr, + logical_plan::{AsLogicalPlan, LogicalExtensionCodec}, + protobuf, +}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Expr; -use prost::{bytes::BytesMut, Message}; +use datafusion_expr::{Expr, LogicalPlan}; +use prost::{ + bytes::{Bytes, BytesMut}, + Message, +}; // Reexport Bytes which appears in the API use datafusion::logical_plan::FunctionRegistry; -pub use prost::bytes::Bytes; +use datafusion::prelude::SessionContext; +use datafusion_expr::logical_plan::Extension; mod registry; @@ -93,6 +101,69 @@ impl Serializeable for Expr { } } +/// Serialize a LogicalPlan as bytes +pub fn logical_plan_to_bytes(plan: &LogicalPlan) -> Result { + let extension_codec = DefaultExtensionCodec {}; + logical_plan_to_bytes_with_extension_codec(plan, &extension_codec) +} + +/// Serialize a LogicalPlan as bytes, using the provided extension codec +pub fn logical_plan_to_bytes_with_extension_codec( + plan: &LogicalPlan, + extension_codec: &dyn LogicalExtensionCodec, +) -> Result { + let protobuf = + protobuf::LogicalPlanNode::try_from_logical_plan(plan, extension_codec)?; + let mut buffer = BytesMut::new(); + protobuf.encode(&mut buffer).map_err(|e| { + DataFusionError::Plan(format!("Error encoding protobuf as bytes: {}", e)) + })?; + Ok(buffer.into()) +} + +/// Deserialize a LogicalPlan from bytes +pub fn logical_plan_from_bytes( + bytes: &[u8], + ctx: &SessionContext, +) -> Result { + let extension_codec = DefaultExtensionCodec {}; + logical_plan_from_bytes_with_extension_codec(bytes, ctx, &extension_codec) +} + +/// Deserialize a LogicalPlan from bytes +pub fn logical_plan_from_bytes_with_extension_codec( + bytes: &[u8], + ctx: &SessionContext, + extension_codec: &dyn LogicalExtensionCodec, +) -> Result { + let protobuf = protobuf::LogicalPlanNode::decode(bytes).map_err(|e| { + DataFusionError::Plan(format!("Error decoding expr as protobuf: {}", e)) + })?; + protobuf.try_into_logical_plan(ctx, extension_codec) +} + +#[derive(Debug)] +struct DefaultExtensionCodec {} + +impl LogicalExtensionCodec for DefaultExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &SessionContext, + ) -> Result { + Err(DataFusionError::NotImplemented( + "No extension codec provided".to_string(), + )) + } + + fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { + Err(DataFusionError::NotImplemented( + "No extension codec provided".to_string(), + )) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index fed01dd5e27ad..e51896a33cb9b 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -49,18 +49,20 @@ impl From for DataFusionError { mod roundtrip_tests { use super::from_proto::parse_expr; use super::protobuf; + use 
crate::bytes::{logical_plan_from_bytes, logical_plan_to_bytes}; use arrow::{ array::ArrayRef, datatypes::{DataType, Field, IntervalUnit, TimeUnit, UnionMode}, }; use datafusion::logical_plan::create_udaf; use datafusion::physical_plan::functions::make_scalar_function; - use datafusion::prelude::{create_udf, SessionContext}; - use datafusion_common::ScalarValue; + use datafusion::prelude::{create_udf, CsvReadOptions, SessionContext}; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ col, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::Sqrt, Expr, Volatility, }; + use std::fmt::Debug; use std::sync::Arc; // Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test @@ -82,6 +84,18 @@ mod roundtrip_tests { Box::new(Field::new(name, dt, nullable)) } + #[tokio::test] + async fn roundtrip_logical_plan() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) + .await?; + let plan = ctx.table("t1")?.to_logical_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); + Ok(()) + } + #[test] fn scalar_values_error_serialization() { let should_fail_on_seralize: Vec = vec![ diff --git a/datafusion/proto/testdata/test.csv b/datafusion/proto/testdata/test.csv new file mode 100644 index 0000000000000..f0e8c98e67acf --- /dev/null +++ b/datafusion/proto/testdata/test.csv @@ -0,0 +1,3 @@ +a,b +1,2 +3,4 \ No newline at end of file From a19b73376dbda7d5301f99775a746967ab34f575 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 May 2022 10:34:02 -0600 Subject: [PATCH 05/11] examples --- datafusion/proto/examples/expr_serde.rs | 34 +++++++++++++++++++++++++ datafusion/proto/examples/plan_serde.rs | 33 ++++++++++++++++++++++++ 2 files changed, 67 insertions(+) create mode 100644 datafusion/proto/examples/expr_serde.rs create mode 100644 datafusion/proto/examples/plan_serde.rs diff --git a/datafusion/proto/examples/expr_serde.rs b/datafusion/proto/examples/expr_serde.rs new file mode 100644 index 0000000000000..81f0f7da70320 --- /dev/null +++ b/datafusion/proto/examples/expr_serde.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::Result; +use datafusion_expr::{col, lit, Expr}; +use datafusion_proto::bytes::Serializeable; + +#[tokio::main] +async fn main() -> Result<()> { + // Create a new `Expr` a < 32 + let expr = col("a").lt(lit(5i32)); + + // Convert it to an opaque form + let bytes = expr.to_bytes().unwrap(); + + // Decode bytes from somewhere (over network, etc.) 
+ let decoded_expr = Expr::from_bytes(&bytes).unwrap(); + assert_eq!(expr, decoded_expr); + Ok(()) +} diff --git a/datafusion/proto/examples/plan_serde.rs b/datafusion/proto/examples/plan_serde.rs new file mode 100644 index 0000000000000..d7dd83f53846c --- /dev/null +++ b/datafusion/proto/examples/plan_serde.rs @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::prelude::*; +use datafusion_common::Result; +use datafusion_proto::bytes::{logical_plan_from_bytes, logical_plan_to_bytes}; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) + .await + .unwrap(); + let plan = ctx.table("t1").unwrap().to_logical_plan().unwrap(); + let bytes = logical_plan_to_bytes(&plan).unwrap(); + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).unwrap(); + assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); + Ok(()) +} From 3c3bb6dd8b16f44bcd35260fe9542036f434af51 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 May 2022 10:38:50 -0600 Subject: [PATCH 06/11] update README --- datafusion/proto/README.md | 45 +++++++++++++++++-------- datafusion/proto/examples/expr_serde.rs | 7 ++-- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/datafusion/proto/README.md b/datafusion/proto/README.md index 9d9cc79e68c94..32912aff99e6b 100644 --- a/datafusion/proto/README.md +++ b/datafusion/proto/README.md @@ -23,31 +23,48 @@ This crate is a submodule of DataFusion that provides a protocol buffer format for representing query plans and expressions. -The following example demonstrates serializing and deserializing a logical expression. +## Serializing Expressions + +Based on [examples/expr_serde.rs](examples/expr_serde.rs) ``` rust +use datafusion_common::Result; use datafusion_expr::{col, lit, Expr}; use datafusion_proto::bytes::Serializeable; -// Create a new `Expr` a < 32 -let expr = col("a").lt(lit(5i32)); +fn main() -> Result<()> { + // Create a new `Expr` a < 32 + let expr = col("a").lt(lit(5i32)); -// Convert it to an opaque form -let bytes = expr.to_bytes().unwrap(); + // Convert it to an opaque form + let bytes = expr.to_bytes()?; -// Decode bytes from somewhere (over network, etc.) -let decoded_expr = Expr::from_bytes(&bytes).unwrap(); -assert_eq!(expr, decoded_expr); + // Decode bytes from somewhere (over network, etc.) + let decoded_expr = Expr::from_bytes(&bytes)?; + assert_eq!(expr, decoded_expr); + Ok(()) +} ``` -The following example demonstrates serializing and deserializing a logical plan. 
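+Plans containing custom `Extension` nodes cannot be round-tripped with the
+default codec. The `_with_extension_codec` variants accept a user-provided
+`LogicalExtensionCodec` instead; a minimal sketch (the wrapper function is
+illustrative, not part of this crate):
+
+```rust
+use datafusion::prelude::SessionContext;
+use datafusion_common::Result;
+use datafusion_expr::LogicalPlan;
+use datafusion_proto::bytes::{
+    logical_plan_from_bytes_with_extension_codec,
+    logical_plan_to_bytes_with_extension_codec,
+};
+use datafusion_proto::logical_plan::LogicalExtensionCodec;
+
+/// Serialize and immediately deserialize `plan`, using `codec` to handle
+/// any Extension nodes it contains.
+fn roundtrip_with_codec(
+    plan: &LogicalPlan,
+    ctx: &SessionContext,
+    codec: &dyn LogicalExtensionCodec,
+) -> Result<LogicalPlan> {
+    let bytes = logical_plan_to_bytes_with_extension_codec(plan, codec)?;
+    logical_plan_from_bytes_with_extension_codec(&bytes, ctx, codec)
+}
+```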
+## Serializing Plans + +Based on [examples/plan_serde.rs](examples/plan_serde.rs) ``` rust -let ctx = SessionContext::new(); -ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()).await.unwrap(); -let plan = ctx.table("t1").unwrap().to_logical_plan().unwrap(); -let bytes = logical_plan_to_bytes(&plan).unwrap(); -let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).unwrap(); +use datafusion::prelude::*; +use datafusion_common::Result; +use datafusion_proto::bytes::{logical_plan_from_bytes, logical_plan_to_bytes}; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()).await.unwrap(); + let plan = ctx.table("t1").unwrap().to_logical_plan().unwrap(); + let bytes = logical_plan_to_bytes(&plan).unwrap(); + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).unwrap(); + assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); + Ok(()) +} ``` [df]: https://crates.io/crates/datafusion diff --git a/datafusion/proto/examples/expr_serde.rs b/datafusion/proto/examples/expr_serde.rs index 81f0f7da70320..9da64f87e2b1e 100644 --- a/datafusion/proto/examples/expr_serde.rs +++ b/datafusion/proto/examples/expr_serde.rs @@ -19,16 +19,15 @@ use datafusion_common::Result; use datafusion_expr::{col, lit, Expr}; use datafusion_proto::bytes::Serializeable; -#[tokio::main] -async fn main() -> Result<()> { +fn main() -> Result<()> { // Create a new `Expr` a < 32 let expr = col("a").lt(lit(5i32)); // Convert it to an opaque form - let bytes = expr.to_bytes().unwrap(); + let bytes = expr.to_bytes()?; // Decode bytes from somewhere (over network, etc.) - let decoded_expr = Expr::from_bytes(&bytes).unwrap(); + let decoded_expr = Expr::from_bytes(&bytes)?; assert_eq!(expr, decoded_expr); Ok(()) } From 15592a7250e87ffef7c5e764111e08aa425e6f7d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 May 2022 10:41:23 -0600 Subject: [PATCH 07/11] prettier --- datafusion/proto/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/proto/README.md b/datafusion/proto/README.md index 32912aff99e6b..ccf8454afeb57 100644 --- a/datafusion/proto/README.md +++ b/datafusion/proto/README.md @@ -27,7 +27,7 @@ This crate is a submodule of DataFusion that provides a protocol buffer format f Based on [examples/expr_serde.rs](examples/expr_serde.rs) -``` rust +```rust use datafusion_common::Result; use datafusion_expr::{col, lit, Expr}; use datafusion_proto::bytes::Serializeable; @@ -50,7 +50,7 @@ fn main() -> Result<()> { Based on [examples/plan_serde.rs](examples/plan_serde.rs) -``` rust +```rust use datafusion::prelude::*; use datafusion_common::Result; use datafusion_proto::bytes::{logical_plan_from_bytes, logical_plan_to_bytes}; From 96f564f766d966f33f8e999ce9685ea64f6bb1ae Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 May 2022 11:55:34 -0600 Subject: [PATCH 08/11] re-export proto crate --- datafusion/core/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 600e24fb8f187..f0ca84af99144 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -231,9 +231,9 @@ pub use datafusion_common as common; pub use datafusion_data_access; pub use datafusion_expr as logical_expr; pub use datafusion_physical_expr as physical_expr; -pub use datafusion_sql as sql; - +pub use datafusion_proto as proto; pub use datafusion_row as row; +pub use 
datafusion_sql as sql; #[cfg(feature = "jit")] pub use datafusion_jit as jit; From 118026efc07b22a9533d1e153afe485125261e2b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 May 2022 12:00:48 -0600 Subject: [PATCH 09/11] Revert "re-export proto crate" This reverts commit 96f564f766d966f33f8e999ce9685ea64f6bb1ae. --- datafusion/core/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index f0ca84af99144..600e24fb8f187 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -231,10 +231,10 @@ pub use datafusion_common as common; pub use datafusion_data_access; pub use datafusion_expr as logical_expr; pub use datafusion_physical_expr as physical_expr; -pub use datafusion_proto as proto; -pub use datafusion_row as row; pub use datafusion_sql as sql; +pub use datafusion_row as row; + #[cfg(feature = "jit")] pub use datafusion_jit as jit; From 3f6c70b66cdfd405e950f189de618f5631f87350 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 29 May 2022 06:04:17 -0600 Subject: [PATCH 10/11] remove unwrap from examples --- datafusion/proto/README.md | 10 ++++++---- datafusion/proto/examples/plan_serde.rs | 9 ++++----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/datafusion/proto/README.md b/datafusion/proto/README.md index ccf8454afeb57..a3878447e042e 100644 --- a/datafusion/proto/README.md +++ b/datafusion/proto/README.md @@ -58,10 +58,12 @@ use datafusion_proto::bytes::{logical_plan_from_bytes, logical_plan_to_bytes}; #[tokio::main] async fn main() -> Result<()> { let ctx = SessionContext::new(); - ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()).await.unwrap(); - let plan = ctx.table("t1").unwrap().to_logical_plan().unwrap(); - let bytes = logical_plan_to_bytes(&plan).unwrap(); - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).unwrap(); + ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) + .await + ?; + let plan = ctx.table("t1")?.to_logical_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); Ok(()) } diff --git a/datafusion/proto/examples/plan_serde.rs b/datafusion/proto/examples/plan_serde.rs index d7dd83f53846c..d98d88b2a46a6 100644 --- a/datafusion/proto/examples/plan_serde.rs +++ b/datafusion/proto/examples/plan_serde.rs @@ -23,11 +23,10 @@ use datafusion_proto::bytes::{logical_plan_from_bytes, logical_plan_to_bytes}; async fn main() -> Result<()> { let ctx = SessionContext::new(); ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) - .await - .unwrap(); - let plan = ctx.table("t1").unwrap().to_logical_plan().unwrap(); - let bytes = logical_plan_to_bytes(&plan).unwrap(); - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).unwrap(); + .await?; + let plan = ctx.table("t1")?.to_logical_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); Ok(()) } From 7baf3e5cfd8d3b0a82e704324995eed3da2fe862 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 29 May 2022 06:13:35 -0600 Subject: [PATCH 11/11] add test with extension_codec --- datafusion/proto/src/lib.rs | 178 +++++++++++++++++++++++++++++++++++- 1 file changed, 175 insertions(+), 3 deletions(-) diff --git a/datafusion/proto/src/lib.rs 
b/datafusion/proto/src/lib.rs index e51896a33cb9b..24ac70346f55c 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -49,7 +49,11 @@ impl From for DataFusionError { mod roundtrip_tests { use super::from_proto::parse_expr; use super::protobuf; - use crate::bytes::{logical_plan_from_bytes, logical_plan_to_bytes}; + use crate::bytes::{ + logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, + logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, + }; + use crate::logical_plan::LogicalExtensionCodec; use arrow::{ array::ArrayRef, datatypes::{DataType, Field, IntervalUnit, TimeUnit, UnionMode}, @@ -57,12 +61,17 @@ mod roundtrip_tests { use datafusion::logical_plan::create_udaf; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionContext}; - use datafusion_common::{DataFusionError, ScalarValue}; + use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; + use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode}; use datafusion_expr::{ col, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::Sqrt, Expr, - Volatility, + LogicalPlan, Volatility, }; + use prost::Message; + use std::any::Any; + use std::fmt; use std::fmt::Debug; + use std::fmt::Formatter; use std::sync::Arc; // Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test @@ -86,6 +95,27 @@ mod roundtrip_tests { #[tokio::test] async fn roundtrip_logical_plan() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) + .await?; + let scan = ctx.table("t1")?.to_logical_plan()?; + let topk_plan = LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode::new(3, scan, col("revenue"))), + }); + let extension_codec = TopKExtensionCodec {}; + let bytes = + logical_plan_to_bytes_with_extension_codec(&topk_plan, &extension_codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &extension_codec)?; + assert_eq!( + format!("{:?}", topk_plan), + format!("{:?}", logical_round_trip) + ); + Ok(()) + } + + #[tokio::test] + async fn roundtrip_logical_plan_with_extension() -> Result<(), DataFusionError> { let ctx = SessionContext::new(); ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) .await?; @@ -96,6 +126,148 @@ mod roundtrip_tests { Ok(()) } + pub mod proto { + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct TopKPlanProto { + #[prost(uint64, tag = "1")] + pub k: u64, + + #[prost(message, optional, tag = "2")] + pub expr: ::core::option::Option, + } + + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct TopKExecProto { + #[prost(uint64, tag = "1")] + pub k: u64, + } + } + + struct TopKPlanNode { + k: usize, + input: LogicalPlan, + /// The sort expression (this example only supports a single sort + /// expr) + expr: Expr, + } + + impl TopKPlanNode { + pub fn new(k: usize, input: LogicalPlan, expr: Expr) -> Self { + Self { k, input, expr } + } + } + + impl Debug for TopKPlanNode { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + self.fmt_for_explain(f) + } + } + + impl UserDefinedLogicalNode for TopKPlanNode { + fn as_any(&self) -> &dyn Any { + self + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + /// Schema for TopK is the same as the input + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + 
vec![self.expr.clone()] + } + + /// For example: `TopK: k=10` + fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "TopK: k={}", self.k) + } + + fn from_template( + &self, + exprs: &[Expr], + inputs: &[LogicalPlan], + ) -> Arc { + assert_eq!(inputs.len(), 1, "input size inconsistent"); + assert_eq!(exprs.len(), 1, "expression size inconsistent"); + Arc::new(TopKPlanNode { + k: self.k, + input: inputs[0].clone(), + expr: exprs[0].clone(), + }) + } + } + + #[derive(Debug)] + pub struct TopKExtensionCodec {} + + impl LogicalExtensionCodec for TopKExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[LogicalPlan], + ctx: &SessionContext, + ) -> Result { + if let Some((input, _)) = inputs.split_first() { + let proto = proto::TopKPlanProto::decode(buf).map_err(|e| { + DataFusionError::Internal(format!( + "failed to decode logical plan: {:?}", + e + )) + })?; + + if let Some(expr) = proto.expr.as_ref() { + let node = TopKPlanNode::new( + proto.k as usize, + input.clone(), + parse_expr(expr, ctx)?, + ); + + Ok(Extension { + node: Arc::new(node), + }) + } else { + Err(DataFusionError::Internal( + "invalid plan, no expr".to_string(), + )) + } + } else { + Err(DataFusionError::Internal( + "invalid plan, no input".to_string(), + )) + } + } + + fn try_encode( + &self, + node: &Extension, + buf: &mut Vec, + ) -> Result<(), DataFusionError> { + if let Some(exec) = node.node.as_any().downcast_ref::() { + let proto = proto::TopKPlanProto { + k: exec.k as u64, + expr: Some((&exec.expr).try_into()?), + }; + + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!( + "failed to encode logical plan: {:?}", + e + )) + })?; + + Ok(()) + } else { + Err(DataFusionError::Internal( + "unsupported plan type".to_string(), + )) + } + } + } + #[test] fn scalar_values_error_serialization() { let should_fail_on_seralize: Vec = vec![